栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch 框架

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch 框架

第一章  数据加载

Dataset:提供一种方式去获取数据及其标签,并告诉我们有多少数据

Dataloader:为后面的网络提供不同的数据形式

class MyData(Dataset):               #创建一个MyData类,去继承Dataset

   def __init__(self,root_dir,label_dir):         #创建全局变量,比如数据的路径
      self.root_dir = root_dir
      self.label_dir = label_dir             #self.设置全局变量
      self.path = os.path.join(self.root_dir,self.label_dir)  #得到的是拼接的路径
      self.img_path = os.listdir(self.path)    #获得path路径下的所有文件(名字),是一个数组

   def __getiem__(self,index):       #获得图片
       img_name = self.img_path[index]   #获得数组中的一个(文件名),不是路径
       img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)#再拼接上文件名
       label = self.label_dir
       return img,label

   def __len__(self):
       return len(self.img_path)    #返回一个长度

#使用
root_dir = "database/train"
data_label_dir = "data"
data_dataset = MyDate(root_dir,data_label_dir)
img,label = data_dataset[0]    #第一张数据
img.show()     #放出来
       
第二章  tensorboard的使用 

主要用于看loss的变化

writer = SummaryWriter("logs")    #把文件存储在logs文件夹下

#有三种主要的使用
writer.add_image()    #用来把图片显示在ten里
writer.add_scalar()   #显示函数
writer.close()

#在终端输入命令,复制地址可打开。
#tensorboard --logdir=logs    其中logdir=事件文件所在文件夹名
#tensorboard --logdir=logs --port=6007    指定端口

其中add_image()读取数据的类型必须是 torch.Tensor,  numpy.array,  或者是string/blobname类型,故 要进行数据类型转换。

opencv-python是最常用来打开numpy类型的包

import numpy as np
from PIL import Image

image_path = '地址'
img_PIL = Image.open(image_path)
img_array = np.array(img_PIL)      #先获得PIL数据类型
print(type(img_array))
第三章   transforms

用来对图像进行变换,也就是输入一个特定格式的图片,经过transforms的函数后输出我们想要的图片结果

totensor数据类型:

img_path = '路径'
img = Image.open(img_path) 
print(img)   #PIL类型的图片

tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)   #将PIL转换为tensor数据类型
 

Resize改变尺寸:

t_resize = transforms.Resize((512,512))
img_size = t_resize(img)
第四章  数据集的加载

转换为tensor数据类型

import torchvision

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

五   神经网络的搭建 

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):  #相当于nn.Module是一个网络框架,我们对其一部分进行更改
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):      #神经网络经过forward得到一个输出(前向传播)
        x = F.relu(self.conv1(x))    # x经过一次卷积conv1,再经过一次非线性relu
        return F.relu(self.conv2(x))    #得到的x再经过一次conv2再经过一次relu

卷积:用卷积核在输入图像上对应相乘再相加。

import torch
input = torch.tensor([[1,2,0,3,1],    #[[表示是二维矩阵   #输入图像
                      [0,1,2,3,1],
                      [1,2,1,0,0],
                      [5,2,3,1,1],
                      [2,1,0,1,1]])

kernel = torch.tensor([[1,2,1],      #卷积核
                       [0,1,0],
                       [2,1,0]])

print(input.shape)              # torch.Size([5,5])
input = torch.reshape(input,(1,1,5,5))   # 1-bachsize为1;1- 平面所以通道为1;(5,5)H,W
kernel = torch.reshape(kernel,(1,1,3,3))

#卷积
import torch.nn.functional as F

output = F.conv2d(input,kernel,stride = 1)     #stride = 1 走一步
print(output)   #结果就是对应相乘相加得到的矩阵

 conv2d的输入要求:input(minibatch,in_channels,H,W)四个参数,而图片的shape只输出(H,W)因此采用reshape函数

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/757189.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号