栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

PyTorch深度学习(6)池化层

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

PyTorch深度学习(6)池化层

Pooling Layers  池化层 1、种类
  • nn.MaxPool1d  下采样池化
  • nn.MaxPool2d
  • nn.MaxPool3d
  • nn.MaxUnpool1d  上采样池化
  • nn.MaxUnpool2d
  • nn.MaxUnpool3d
  • nn.AvgPool1d  平均池化
  • nn.AvgPool2d
  • nn.AvgPool3d
  • nn.AdaptiveMaxPool1d  自适应最大池化
  • nn.AdaptiveMaxPool2d
  • nn.AdaptiveMaxPool3d
  • nn.AdaptiveAvgPool1d  自适应平均池化
  • nn.AdaptiveAvgPool2d
  • nn.AdaptiveAvgPool3d

2、参数

torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

kernel_size:取最大值的窗口,如为3,则生成3×3的窗口

stride:窗口步进大小,默认值为kernel_size

padding:预留空白区域大小

dilation:核与核之间空格,又称空洞卷积

 ceil_mode:为True时,使用的是ceil模式而不是floor模式,例如:2.31,Floor为2,Ceiling为3

ceil为True,保留,为False,不保留

池化函数使用某一位置的相邻输出的总体统计特征来代替网络在该位置的输出,本质是 降采样,可以大幅减少网络的参数量

3、池化作用

(马赛克)

池化一般跟在卷积层后,卷积层用来提取特征,一般有相应特征的位置是比较大的数字,最大池化可以提出来这一部分有相应特征的信息

池化操作保存纹理

4、池化结果

原图片

 池化后图片

5、具体代码
import ssl

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.nn import MaxPool2d

ssl._create_default_https_context = ssl._create_unverified_context

dataset = torchvision.datasets.CIFAR10(root="./dataset", train=False,
                                       transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=64)

input_data = torch.tensor([[1, 2, 0, 3, 1],
                           [0, 1, 2, 3, 1],
                           [1, 2, 1, 0, 0],
                           [5, 2, 3, 1, 1],
                           [2, 1, 0, 1, 1]], dtype=torch.float32)
input_data = torch.reshape(input_data, (-1, 1, 5, 5))  # -1为自动调节前面的维度

class TestMaxPool(nn.Module):
    def __init__(self):
        super(TestMaxPool, self).__init__()
        self.maxPool1 = MaxPool2d(kernel_size=3, ceil_mode=False)  # ceil 天花板,向上取整  floor 向下取整

    def forward(self, in_data):
        output_data = self.maxPool1(in_data)
        return output_data


txp = TestMaxPool()
result = txp(input_data)
print(result)

writer = SummaryWriter("logs")

step = 0
for data in dataloader:
    imgs, target = data
    writer.add_image("input_img", imgs, step, dataformats="NCHW")

    output = txp(imgs)
    writer.add_image("output_img", output, step, dataformats="NCHW")
    step = step + 1

writer.close()

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/325780.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号