栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

深度学习——02pytorch卷积

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

深度学习——02pytorch卷积

一、卷积pytorch

在pytorch中有两种方式,一种是torch.nn.Conv2d()

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
参数解释
in_channels输入图像通道数
out_channels卷积产生的通道数
kernel_size卷积核尺寸
stride卷积步长,默认为1
padding填充操作,控制padding_mode的数目
padding_modepadding模式,默认为Zero-padding
dilation扩张操作:控制kernel点(卷积核点)的间距,默认值:1
groupsgroup参数的作用是控制分组卷积,默认不分组,为1组
bias添加一个可学习的偏差。默认:True

一种是torch.nn.functional.conv2d()

torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

weight权重 也就是卷积核

二、卷积pytorch代码实现

在这里我使用的是第二种 torch.nn.functional.conv2d()
1、定义输入,和卷积核

input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]])
                    
kernel = torch.tensor([[1, 2, 1],
                       [0, 1, 0],
                       [2, 1, 0]])        

2、reshape输入和卷积核

# [ batch_size, channels, height, width ]
input = torch.reshape(input,(1,1,5,5)) 
kernel = torch.reshape(kernel,(1,1,3,3))      

3、conv2d卷积

output = F.conv2d(input,kernel,stride=1)    

4、输出

print(output)  
# tensor([[[[10, 12, 12],
#           [18, 16, 16],
#           [13,  9,  3]]]])

代码:

import torch
import torch.nn.functional as F

input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]])
kernel = torch.tensor([[1, 2, 1],
                       [0, 1, 0],
                       [2, 1, 0]])
input = torch.reshape(input,(1,1,5,5))
kernel = torch.reshape(kernel,(1,1,3,3))
output = F.conv2d(input,kernel,stride=1)
print(output)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/887148.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号