栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

lenet5鍗风Н绁炵粡缃戠粶pytorch_lenet-5浠g爜?

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

lenet5鍗风Н绁炵粡缃戠粶pytorch_lenet-5浠g爜?

from torch import nn
import torch
from torch.nn import functional as F

class LeNet(nn.Module):
    def __init__(self, num_class = 10):
        #num_class为需要分到的类别数
        super().__init__()
        #输入像素大小为1*28*28
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size = 5, padding = 2),#输出为6*28*28
            nn.AvgPool2d(kernel_size= 2, stride= 2),#输出为6*14*14,此处也可用MaxPool2d
            nn.Conv2d(6, 16, kernel_size = 5),#输出为16*10*10
            nn.ReLU(),#论文中为sigmoid,但极易出现梯度消失
            nn.AvgPool2d(kernel_size= 2, stride= 2),#输出为16*5*5
            nn.Flatten()#将通道及像素进行合并,方便进一步使用全连接层
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(), #论文中同样为sigmoid
            nn.Linear(120, 84),
            nn.Linear(84, 10))
    def forward(self, x):
            x = self.features(x)
            x = self.classfier(x)

网络结构:

LeNet(
  (features): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (3): ReLU()
    (4): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (5): Flatten(start_dim=1, end_dim=-1)
  )
  (classifier): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): Linear(in_features=84, out_features=10, bias=True)
  )
)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/786715.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号