栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

【Pytorch实现】——LeNet网络

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

【Pytorch实现】——LeNet网络

【Pytorch实现】——LeNet网络
import torch
from torch import nn
from d2l import torch as d2l

# 继承nn.Module类,实现__init__和forward
class Reshape(nn.Module):
  def forward(self,x):
    # 将图片形状变成BxCxHxW
    return x.reshape(-1,1,28,28)

# nn.Sequential可以包装任何继承nn.Module实现的实例
net = nn.Sequential(
    Reshape(),
    # F = 2 * P + 1 特征图尺寸不变
    nn.Conv2d(1,6,kernel_size=5,padding=2),
    nn.Sigmoid(),
    # 平均池化 特征图尺寸减小为原来的一半
    nn.AvgPool2d(kernel_size=2,stride=2),
    nn.Conv2d(6,16,kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2,stride=2),
    # nn.Flatten操作仅保留第一个维度将其他维度打平
    nn.Flatten(),
    nn.Linear(16*5*5,120),
    nn.Sigmoid(),
    nn.Linear(120,84),
    nn.Sigmoid(),
    nn.Linear(84,10)
)

X = torch.randn(size=(4,1,28,28),dtype=torch.float32)
# nn.Sequential()包装后相当于将不同层存放到一个list中
# 迭代打印经过每一层后特征图尺寸变化(get)
for layer in net:
  X = layer(X)
  print(layer.__class__.__name__, 'output shape: t', X.shape)
Reshape output shape: 	 torch.Size([4, 1, 28, 28])
Conv2d output shape: 	 torch.Size([4, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([4, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([4, 6, 14, 14])
Conv2d output shape: 	 torch.Size([4, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([4, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([4, 16, 5, 5])
Flatten output shape: 	 torch.Size([4, 400])
Linear output shape: 	 torch.Size([4, 120])
Sigmoid output shape: 	 torch.Size([4, 120])
Linear output shape: 	 torch.Size([4, 84])
Sigmoid output shape: 	 torch.Size([4, 84])
Linear output shape: 	 torch.Size([4, 10])
  • 上面我们是手动打印每一层的名称和参数,其实Pytorch中提供了一种更加便捷的工具summary
import torch
from torch import nn
from d2l import torch as d2l
from torchsummary import summary

# 继承nn.Module类,实现__init__和forward
class Reshape(nn.Module):
  def forward(self,x):
    # 将图片形状变成BxCxHxW
    return x.reshape(-1,1,28,28)

# nn.Sequential可以包装任何继承nn.Module实现的实例
net = nn.Sequential(
    Reshape(),
    # F = 2 * P + 1 特征图尺寸不变
    nn.Conv2d(1,6,kernel_size=5,padding=2),
    nn.Sigmoid(),
    # 平均池化 特征图尺寸减小为原来的一半
    nn.AvgPool2d(kernel_size=2,stride=2),
    nn.Conv2d(6,16,kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2,stride=2),
    # nn.Flatten操作仅保留第一个维度将其他维度打平
    nn.Flatten(),
    nn.Linear(16*5*5,120),
    nn.Sigmoid(),
    nn.Linear(120,84),
    nn.Sigmoid(),
    nn.Linear(84,10)
)

X = torch.randn(size=(4,1,28,28),dtype=torch.float32)
summary(net,(4,1,28,28))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
           Reshape-1            [-1, 1, 28, 28]               0
            Conv2d-2            [-1, 6, 28, 28]             156
           Sigmoid-3            [-1, 6, 28, 28]               0
         AvgPool2d-4            [-1, 6, 14, 14]               0
            Conv2d-5           [-1, 16, 10, 10]           2,416
           Sigmoid-6           [-1, 16, 10, 10]               0
         AvgPool2d-7             [-1, 16, 5, 5]               0
           Flatten-8                  [-1, 400]               0
            Linear-9                  [-1, 120]          48,120
          Sigmoid-10                  [-1, 120]               0
           Linear-11                   [-1, 84]          10,164
          Sigmoid-12                   [-1, 84]               0
           Linear-13                   [-1, 10]             850
================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.12
Params size (MB): 0.24
Estimated Total Size (MB): 0.37
----------------------------------------------------------------
  • 调用summary函数后输出内容包括:
  • 名称Layer
  • 输出特征图大小Output Shape (注意输出大小的Batch维度为-1)
  • 参数量 Param
  • 总参数量 Total params
  • 可训练参数量 Trainable params
  • 不可训练参数量 Non-trainable params
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/303988.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号