栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

【深度学习代码】x = x.view(x.size(0), -1)理解

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

【深度学习代码】x = x.view(x.size(0), -1)理解

x = x.view(x.size(0), -1) 一般出现在前向传播过程中卷积层与全连接层的交替的位置。下面代码为Lenet网络训练cifar10数据集的情况。

class LeNet(nn.Module):
def init(self):
super(LeNet, self).init()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(1655, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
    out = F.relu(self.conv1(x))
    out = F.max_pool2d(out, 2)
    out = F.relu(self.conv2(out))
    out = F.max_pool2d(out, 2)
    out = out.view(out.size(0), -1)
    out = F.relu(self.fc1(out))
    out = F.relu(self.fc2(out))
    out = self.fc3(out)
    return out

out = out.view(out.size(0), -1)中out.size(0) 的out大小为[batchsize,16,5,5],因此out.size(0) = batchsize;如果为out.size(1), 则 = 16。我自己训练时batchsize设置为32,所以out.size(0) = 32。
out.view(out.size(0),-1)则是把out转化为行数为out.size(0) ,列数相应自动生成(对于我的代码列数则为1655 = 400)。

out = out.view(out.size(0), -1)函数功能相当于torch.flatten()

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/1012619.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号