栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch中Resnet的代码实现解析

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch中Resnet的代码实现解析

本文介绍Pytorch对resnet的代码实现。文章从基本框架出发,由浅入深地再介绍部分细节设计,因此建议顺序阅读。

前言

Resnet是最著名的CNN网络之一,残差模型也在后续网络设计中有广泛的应用。本文主要解读Pytorch的torchvision == 0.11.1模块中对Resnet的代码实现。本文假设读者已经对CNN以及Resnet的结构有一定了解。

首先给大家推荐相关的优秀资源,《pytorch中残差网络resnet的源码解读》和李沐Resnet论文精读。

Resnet使用

直接使用Pytorch中的Resnet系列网络是很方便的,如下所示代码使用Resnet18。

import torch
import torchvision

model = torchvision.models.resnet18(pretrained=True)
data = torch.rand(1, 3, 224, 224)
prediction = model(data)

我们“F12”直接进入函数的定义。首先我们可以看的在torchvision-models内包含大量已经实现的网络,除了Resnet、Vgg等图像分类网络,还包括Faster-rcnn、Ssd等目标检测网络,Deeplabv3语义分割网络等等。

renset基本结构

在resnet.py文件内,我们可以看到共包含resnet18、resnet32等9个网络。这里我们只介绍原始论文中提到的5种网络的实现,即以下我们忽略group、dilation等相关一些参数。

我们可以看到当我们直接调用如resnet18时,首先会调用内置函数_resnet。在_resnet中,网络模型对象类型model则由ResNet类生成,在其forward函数可以看到其网络结构,如下所示。

x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)

return x

可以看到网络依次为:7×7卷积层 - BN层(默认使用batch normalization)- 激活层 - 最大池化层 - 4个layer层-平均池化层 - 展平层 - 全连接层。

如下放入原始论文中网络结构表格(这个表格十分重要),conv2_x至conv5_x分别对应四个layer层。

" />

在ResNet类中__init__魔术方法中还可以看到对卷积中groups、dilate、norm_layer层的自定义接口,以及一些特殊的参数初始化方法。这里不过多介绍。

resnet重要函数解析

以上介绍了resnet系列网络的基本结构,下面我们针对其一些重要内部函数做详细的介绍,尤其是残差块的设计。

_resnet函数
_resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
_resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
_resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
_resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
_resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)

以上列出5个resnet系列网络在调用_resnet函数时的传入的参数。我们可以看到:

  • resnet18和resnet34使用的是BasicBlock残差块,拥有较深层的50、101、152使用的是Bottleneck残差块。
  • 第三个列表参量表示四个layer层中每一层中残差块的个数。
_make_layer函数
def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
    norm_layer = self._norm_layer
    downsample = None
    previous_dilation = self.dilation
    if dilate:
        self.dilation *= stride
        stride = 1
    if stride != 1 or self.inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            conv1x1(self.inplanes, planes * block.expansion, stride),
            norm_layer(planes * block.expansion),
        )

    layers = []
    layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                        self.base_width, previous_dilation, norm_layer))
    self.inplanes = planes * block.expansion
    for _ in range(1, blocks):
        layers.append(block(self.inplanes, planes, groups=self.groups,
                            base_width=self.base_width, dilation=self.dilation,
                            norm_layer=norm_layer))

    return nn.Sequential(*layers)

_make_layer函数的实现如上所示,作为网络中最重要的conv2_x至conv5_x的实现。我们首先关注其返回值是一个layers列表表示的网络序列。在layers中依次添加了对应数量的block,即BasicBlock或Bottleneck残差块,接下来我们来解析两种残差块。

同时我们可以发现range(1, blocks),即第一个残差块与后面的残差块是不一样的,以及block.expansion、downsample等在后面进行解释。

BasicBlock残差块
def forward(self, x: Tensor) -> Tensor:
    identity = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)

    out = self.conv2(out)
    out = self.bn2(out)

    if self.downsample is not None:
        identity = self.downsample(x)

    out += identity
    out = self.relu(out)

    return out

BasicBlock类的forward函数如上所示,可以看到:

  • 最重要的残差连接部分在实现上就是一个+=符号。
  • downsample函数的目的是,当out与identity的.shape不一致时,需要对identity进行下采样,保持与out一致,以满足相加条件。
Bottleneck残差块
def forward(self, x: Tensor) -> Tensor:
    identity = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)

    out = self.conv2(out)
    out = self.bn2(out)
    out = self.relu(out)

    out = self.conv3(out)
    out = self.bn3(out)

    if self.downsample is not None:
        identity = self.downsample(x)

    out += identity
    out = self.relu(out)

    return out

我们可以清楚的看到Bottleneck与BasicBlock的实现区别,这里不做解释。

值得一提的是在Bottleneck中expansion值为4,BasicBlock中该值为1。该值代表了对于Bottleneck的第三个卷积层通道数增加了4倍。

再谈_make_layer函数

前面介绍了downsample函数的作用,那么何时该使用呢?我们回到_make_layer函数,看到添加downsample的判断语句:

if stride != 1 or self.inplanes != planes * block.expansion:

即当stride不为1,即卷积层图像图小尺寸发生变化时,或者通道数发生变化时。这种变化只在layer层之间发生。具体地说:

  • 何时stride不为1,即图像尺寸减半?在conv3_x(注意是自3起)至conv5_x每一个layer层的第一个残差块对图像尺寸进行减半,具体地说对于BasicBlock块是第一个卷积层,对于Bottleneck块是第二个卷积层。
  • 何时通道数发生变化?每一个layers的通道数不一致,即输入到每一个layers的第一个残差块的张量的通道数和输入到第二个残差块的张量的通道数是不一样的。(注意对于Bottleneck块,虽然第三个卷积层通道数发生变化,但是在一个layers内,每一个残差块的输出与下一个残差块的输出是不变的)。

顺理成章的,我们也可以解释为什么layers列表分为两部分,因为在添加的第一个残差块需要传入了stride, downsample参量。

总结

以上既是对Pytorch中Resnet的网络代码实现的解析。对于Resnet网络来说,只通过文中的表格就能够基本上理解Resnet的结构。通过实现代码,可以更容易发现其不同深度的Resnet的区别和联系,以及一些代码实现细节。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/664781.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号