栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

PyTorch Week 3——模型创建

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

PyTorch Week 3——模型创建

系列文章目录

PyTorch Week 2——Dataloader与Dataset

PyTorch Week 1

文章目录
  • 系列文章目录
  • 一、模型
  • 二、“构建”与“拼接”
    • 1. LeNet的实现
    • 2.nn.Module
  • 总结


一、模型

模型包括两个部分,一是模型创建,二是权值初始化。
模型创建就是使用卷积层、池化层、激活函数等构建网络的每一层,然后将网络层拼接起来。即两要素:“构建”、“拼接”

二、“构建”与“拼接” 1. LeNet的实现

在Pytorch中,实现模型创建功能的是nn.Module的__init__ 函数和 forward函数,

#定义LeNet类,继承nn.Module,
class LeNet(nn.Module):
    def __init__(self, classes):#模型构建的第一个要素,就是在__init__中构建子模块
    def forward(self, x):#在这里实现了拼接
#实例化
net = LeNet(classes=2)
#调用
outputs = net(inputs)

步入net,进入module的call函数,函数调用forward

def _call_impl(self, *input, **kwargs):
        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
        # If we don't have any hooks, we want to skip the rest of the logic in
        # this function, and just call forward.
        if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
                or _global_forward_hooks or _global_forward_pre_hooks):
            return forward_call(*input, **kwargs)#函数调用forward

步入forward,进入LeNet类的forward函数,在这里完成计算

2.nn.Module

nn.Module包括八个有序字典去定义模型

  • parameters:储存管理nn.Parameter类
  • modules:储存管理nn.Module类

从super(LeNet, self).init()步入module,在module.py的__init__函数中定义了八个字典

self.training = True
self._parameters = OrderedDict()#1
self._buffers = OrderedDict()#2
self._non_persistent_buffers_set = set()
self._backward_hooks = OrderedDict()#3
self._is_full_backward_hook = None
self._forward_hooks = OrderedDict()#4
self._forward_pre_hooks = OrderedDict()#5
self._state_dict_hooks = OrderedDict()#6
self._load_state_dict_pre_hooks = OrderedDict()#7
self._modules = OrderedDict()#8

步出后,变量中出现了这八个字典

接下来,步入conv2d

class LeNet(nn.Module):
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)#步入con2d

conv2d的__init__函数中,调用了父类convnd的__init__方法

super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
            False, _pair(0), groups, bias, padding_mode, **factory_kwargs)

而convnd中继承了nn.Module,调用父类Module的__init__

class _ConvNd(Module):
	def __init__():
		super(_ConvNd, self).__init__()

conv2d调用convnd,convnd调用Module,所以conv2d是一个Module类,定义了八个有序字典来管理属性

步出至LeNet后,发现_modules中已经包含conv1

接下来看如何构建conv2,步入这一行,进行实例化nn.Conv2d操作,步出后,还未完成赋值

class LeNet(nn.Module):
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)#步入

再步入,进入__setattr__函数

   def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
        def remove_from(*dicts_or_sets):
            for d in dicts_or_sets:
                if name in d:
                    if isinstance(d, dict):
                        del d[name]
                    else:
                        d.discard(name)

setattr首先判断传入的是参数还是模型,如果是模型,则赋值

if isinstance(value, Parameter):
elif params is not None and name in params:
else:
	modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
                modules[name] = value#

在modules[name] = value中,name就是’conv2’,value就是定义的卷积

以此类推,就完成了_module的定义

小结:
pytorch中网络模型的构建机制
定义的网络继承nn.Module类后,__init__中进行属性赋值是,会被setattr拦截,判断其类型,送入parameter字典或者module字典中,进行管理

总结

本节了解了模型的创建步骤,构建和拼接,__init__完成构建,forward完成拼接
nn.Modelu有八个字典去管理模型的属性
在定义模型时,setattr会拦截判断赋值的类型,根据类型赋值给parameter或者module。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/339420.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号