栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

【随笔】使用pytorch训练Fashion mnist,注释全

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

【随笔】使用pytorch训练Fashion mnist,注释全

神经网络的训练过程分为7个步骤
  1. 从训练集中得到一批数据
  2. 将数据传递给网络
  3. 计算损失(这是网络返回的预测值与真实值之间的差异) loss function(损失函数)执行第三步
  4. 计算受损失函数的梯度和网络的权值 back propagation(反向传播)执行第四步
  5. 更新权重,使用梯度减少损失 optimization algorithm(优化算法)实现第五步
  6. 重复步骤一到现在,直到一个周期完成
  7. 第七步是重复步骤1到6,以获得所期望的精确度
import torch   # torch是顶级的pytorch的包和张量库

import torch.nn as nn  # 包含了用来搭建各个层的模块(Modules),比如全连接、二维卷积、池化
import torch.nn.functional as F # 包含了常用的激活函数,如不具有可学习的参数(relu、leaky_relu、prelu、sigmoid)等

import torchvision  # torchvision是一个提供对流行的数据集、模型结构和计算机视觉的图像转换的访问的包
import torchvision.transforms as transforms # 这个接口是我们能够访问图像处理的通用转换,图像到tensor ,numpy 数组到tensor , tensor 到 图像等

import torch.optim as optim # 函数可以选择优化器
import numpy as np

train_set=torchvision.datasets.FashionMNIST(    # 从torchvision.datasets中获取数据集FashionMNIST
       root='./data/FashionMNIST'  # 训练集下载路径
        ,train=True       # 训练参数设为true,意思是数据用于训练集,在此数据集中,6万张用于训练数据,1万张用于测试数据
        ,download=False   # download设为True表示,如果数据集不在硬盘上,就执行下载,如果已经下载可以设置为False
        ,transform=transforms.Compose([
            transforms.ToTensor()   # 把图像转换为张量,用内置的Totensor
        ]))  
C:Users82325anaconda3envspytorchlibsite-packagestorchvisiondatasetsmnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ..torchcsrcutilstensor_numpy.cpp:180.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
  这个警告是因为我已经下载过了,大家可以忽略

解释:

  1. torchvision.datasets: 一些加载数据的函数及常用的数据集接口
  2. compose类会将transforms列表里面的transform操作进行遍历,Compose把多个步骤整合到一起
  3. ToTensor:Convert a PIL Image or numpy.ndarray to tensor,将PIL Image对象转换成Tensor,会自动将【0,255】归一化至【0,1】
def get_num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

解释:

  1. argmax() 函数返回指定维度最大值的序号,dim的不同值表示不同维度,0表示沿着行查找,1表示沿着列查找,dim=1意味着找到张量中各自的最大值所在的索引
  2. 为了实现比较,使用eq() 函数,eq() 函数计算argmax输出和标签张量之间的元素相等运算,如果argmax输出中的预测类别与标签匹配,则为1,否则为0
  3. sum() 函数,可以将输出缩减为该标量值张量内的单个正确预测数
  4. 使用item()方法返回张量元素的值,即返回数目的正确预测
创建简单的cnn网络
class Network(nn.Module):
    def __init__(self):
        super(Network,self).__init__() # 直白的说super().__init__(),就是继承父类的init方法
        # 定义conv1、conv2函数的是图像卷积函数:括号内输入通道、输出通道、卷积核
        self.conv1=nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5) 
        self.conv2=nn.Conv2d(in_channels=6,out_channels=12,kernel_size=5)
    
        # 定义fc1、fc2函数是全连接层函数,括号内输入特征、输出特征
        self.fc1=nn.Linear(in_features=12*4*4,out_features=120)  # 括号内输入特征、输出特征
        self.fc2=nn.Linear(in_features=120,out_features=60)
        self.out=nn.Linear(in_features=60,out_features=10) 
    def forward(self,t): 
        # 输入t经过卷积conv1之后,经过激活函数relu,使用2x2的窗口进行最大池化Max pooling,步长为2,然后更新到t
        t=F.max_pool2d(F.relu(self.conv1(t)),kernel_size=2,stride=2)
        
        # 输入t经过卷积conv2之后,经过激活函数relu,使用2x2的窗口进行最大池化Max pooling,然后更新到t
        t=F.max_pool2d(F.relu(self.conv2(t)),kernel_size=2,stride=2)
        
        # reshape函数将张量t变形成一维的向量形式,总特征数并不改变,为接下来的全连接作准备
        t=t.reshape(-1,12*4*4)
        # 输入t经过全连接1,再经过relu激活函数,然后更新t
        t=F.relu(self.fc1(t))
        # 输入t经过全连接2,再经过relu激活函数,更新t
        t=F.relu(self.fc2(t))

        return t

解释:

  1. super是用来解决多重继承问题的,直接用类名调用父类方法在使用单继承的时候没问题,但是如果使用多继承,会涉及到查找顺序、重复调用等种种问题
  2. super(类名,self)init()
  3. 二维卷积nn.Conv2d用于图像数据,对宽度和高度都进行卷积
  4. nn.Linear()用于设置网络中的全连接层,需要注意的是全连接层的输入与输出都是二维张量
参数的设置
  1. 卷积核、输出通道和输出特征的大小是随意确定的,通过测试和调优这些参数找到最有效的值,卷积核一般设置为 5 ∗ 5 5*5 5∗5或 3 ∗ 3 3*3 3∗3
  2. 第一个卷积层的输入通道为构成训练集的图像内部的彩色通道数量,输出层的输出特征是训练集中类的数量,其它数字一层的输入是上一层的输出,所以卷积层的所有输入通道和线性层中的输入特征都依赖于上一次的数据
  3. 卷积层转换到线性层时,必须使张量变平,这就是为什么有 12 ∗ 4 ∗ 4 12*4*4 12∗4∗4作为输入特征数,12来源于前一次输出通道数量, 4 ∗ 4 4*4 4∗4是12个输出通道的高度和宽度,从[1,28,28]输入张量开始,当张量到达第一个线性层时,高度和宽度的尺寸从 28 ∗ 28 28*28 28∗28减少到 4 ∗ 4 4*4 4∗4,这个减少是由于卷积和池化操作造成的,公式:(图像尺度-滤波器数)/步长+1,例如通过第一层卷积运算(28-5)/1+1=24,池化操作时(24-2)/2+1=12,再经过一次卷积运算和池化运算就得到了[1,12,4,4]
batch_size=1000   # 一次训练所抓取的数据样本数量
lr=0.01   # 学习率

network=Network() # 实例化类
train_loader=torch.utils.data.DataLoader(train_set,batch_size=100)  
optimizer=optim.Adam(network.parameters(),lr=0.01)  


for epoch in range(5):  # 五个周期
    
    total_loss=0  # 为追踪损失,我们将追踪正确预测的数量,所以在循环的顶部创建俩个变量,初始化为0
    total_correct=0

    # batch=next(iter(train_loader))  # 从训练集中抽取样本,将训练加载实例传递给内部函数,使用next获得下一批
    for batch in train_loader:    # 使用遍历来处理所有的批次而不是一个
        images,labels=batch # 将样本压缩到一个图像和标签中,因为处理的是一批所以对变量名使用复数形式

        preds=network(images)  # 把图像传递给网络,结果是预测张量
        loss=F.cross_entropy(preds,labels)  # 调用交叉熵损失函数,来源于nn.functional,通过预测和标签计算损失,并返回一个张量

        optimizer.zero_grad()  # 告诉优化器把梯度属性中权重的梯度归零,因为pytorch会积累梯度,所以计算梯度之前,必须确保现在没有任何梯度值
        loss.backward()  # 调用反向函数,计算梯度,backward是反向传播的简称
        optimizer.step() # 更新权重

        total_loss+=loss.item()
        total_correct+=get_num_correct(preds,labels)

    print("epoch:",epoch,"total_correct:",total_correct,"loss:",total_loss)

epoch: 0 total_correct: 47257 loss: 336.4210552871227
epoch: 1 total_correct: 51110 loss: 239.50188337266445
epoch: 2 total_correct: 51872 loss: 220.99213953316212
epoch: 3 total_correct: 52169 loss: 211.96785034239292
epoch: 4 total_correct: 52433 loss: 205.83596700429916

解释

  1. PyTorch中数据读取的一个重要接口是torch.utils.data.DataLoader,结构为torch.utils.data.DataLoader(dataset,batch_size=1,shuffle=True)
  • dataset:下载好的训练集
  • batch_size:一次训练所抓取的数据样本数量
  • shuffle:是否将数据打乱,可以为True或者False,深度学习项目,在训练之前,一般均会对数据集做shuffle,打乱数据之间的顺序,让数据随机化,这样可以避免过拟合,例如:数据集是1,1,1,。。。。1,2,2,。。。。2,所有的1都在2前面,如果不shuffle,模型训练一段时间内只看到了1,必然会过拟合于1,一段时间内又只能看到2,必然又过拟合于2,这样的模型泛化能力必然很差。
  1. f-string 格式化字符串以 f 开头,后面跟着字符串,字符串中的表达式用大括号 {} 包起来,它会将变量或表达式计算后的值替换进去,意思就是用大括号 {} 表示被替换字段
  2. 列表推导式[item for 变量 in 列表] item:想放在列表的元素,即后面循环的元素本身
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/580425.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号