栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

神经网络的一些小函数

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

神经网络的一些小函数

这里写目录标题
    • isinstance(object, classinfo)
    • model.modules()和model.children()
    • torch.flatten(x, 1)
    • nn.ModuleList 和 nn.Sequential
    • torch.clamp、torch.unsqueeze

isinstance(object, classinfo)

如果对象的类型与参数二的类型(classinfo)相同则返回 True,否则返回 False

a = 2
isinstance (a,int)
True
import torch.nn as nn
avgpool = nn.AdaptiveAvgPool2d((1, 1))
if isinstance(avgpool , nn.AdaptiveAvgPool2d):
	print('yes
m=nn.Sequential(nn.Conv2d(3, 16, kernel_size=1, stride=2, bias=False),
                nn.Conv2d(16, 16, kernel_size=1, stride=2, bias=False))
    if type(m) is nn.Sequential:
        print('yes')
    if isinstance(m,nn.Sequential):
        print('yes2')
yes
yes2
model.modules()和model.children()

model.modules()和model.children()均为迭代器,model.modules()会遍历model中所有的子层,而model.children()仅会遍历当前层

# model.modules()类似于 [[1, 2], 3],其遍历结果为:
[[1, 2], 3], [1, 2], 1, 2, 3
 
# model.children()类似于 [[1, 2], 3],其遍历结果为:
[1, 2], 3
torch.flatten(x, 1)

torch.flatten(t, start_dim=0, end_dim=-1) 的实现原理如下。假设类型为 torch.tensor 的张量 t 的形状如下所示:(2,4,3,5,6),则 orch.flatten(t, 1, 3).shape 的结果为 (2, 60, 6)。将索引为 start_dim 和 end_dim 之间(包括该位置)的数量相乘,其余位置不变。因为默认 start_dim=0,end_dim=-1,所以 torch.flatten(t) 返回只有一维的数据。
假设t的shape为(2,3,4,5,6)
torch.flatten(t, 1,3)结果为(2,60,6)

nn.ModuleList 和 nn.Sequential

加入到 nn.ModuleList 里面的 module 是会自动注册到整个网络上的

linears = nn.ModuleList([nn.Linear(10,10) for i in range(2)])

nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言,可以在forward函数里指定顺序

class net3(nn.Module):
    def __init__(self):
        super(net3, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)])
    def forward(self, x):
        x = self.linears[2](x)
        x = self.linears[0](x)
        x = self.linears[1](x) 
        return x

nn.Sequential,不同于 nn.ModuleList,它已经实现了内部的 forward 函数,而且里面的模块必须是按照顺序进行排列的,所以直接使用 nn.Sequential 不用写 forward 函数,因为它内部已经帮你写好了。一般情况下 nn.Sequential 的用法是来组成卷积块 (block),然后像拼积木一样把不同的 block 拼成整个网络,让代码更简洁,更加结构化
两种初始化方式

# Example of using Sequential
model1 = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )
print(model1)
# Sequential(
#   (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
#   (1): ReLU()
#   (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
#   (3): ReLU()
# )

# Example of using Sequential with OrderedDict
import collections
model2 = nn.Sequential(collections.OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))
print(model2)
# Sequential(
#   (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
#   (relu1): ReLU()
#   (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
#   (relu2): ReLU()
# )

参考知乎

torch.clamp、torch.unsqueeze
img1=torch.clamp(img1,0,255)
x=torch.unsqueeze(x,dim=0)#增加维度

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/268580.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号