栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

【Pytorch】nn.Linear,nn.Conv

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

【Pytorch】nn.Linear,nn.Conv

nn.Linear

nn.Conv1d

当nn.Conv1d的kernel_size=1时,效果与nn.Linear相同,不过输入数据格式不同:
https://blog.csdn.net/l1076604169/article/details/107170146

import torch


def count_parameters(model):
    """Count the number of parameters in a model."""
    return sum([p.numel() for p in model.parameters()])


conv = torch.nn.Conv1d(3, 32, kernel_size=1)
print(count_parameters(conv))
# 128

linear = torch.nn.Linear(3, 32)
print(count_parameters(linear))
# 128

print(conv.weight.shape)
# torch.Size([32, 3, 1])
print(linear.weight.shape)
# torch.Size([32, 3])

# use same initialization
linear.weight = torch.nn.Parameter(conv.weight.squeeze(2))
linear.bias = torch.nn.Parameter(conv.bias)

tensor = torch.randn(128, 256, 3)   # [batch, feature_num,feature_size]
permuted_tensor = tensor.permute(0, 2, 1).clone().contiguous()  # [batch, feature_size,feature_num]

out_linear = linear(tensor)
print(out_linear.mean())
# tensor(0.0344, grad_fn=)
print(out_linear.shape)
# torch.Size([128, 256, 32])


out_conv = conv(permuted_tensor)
print(out_conv.mean())
# tensor(0.0344, grad_fn=)
print(out_conv.shape)
# torch.Size([128, 32, 256])


nn.Conv2d

nn.Conv3d

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/1038220.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号