栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch查看模型某一层的参数以及参数初始化

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch查看模型某一层的参数以及参数初始化

目录
  • 查看模型特定层的参数
    • 对于nn.Sequencial模块
    • 对于非nn.Sequencial模块(继承nn.Module类自定义一个模型类)
  • 参数初始化方法

查看模型特定层的参数 对于nn.Sequencial模块
  1. 除了可以通过下面这种方法访问模型中特定层的参数外, 还可以以索引的形式访问某一层特定层
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 2))
for name, param in net.named_parameters():
    print(name, param.shape)
print('-'*50)
print('0.weight: ', net.state_dict()['0.weight'].shape)


2. 索引方式

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 2))
for name, param in net[0].named_parameters():
    print(name, param.shape)

对于非nn.Sequencial模块(继承nn.Module类自定义一个模型类)

只能用字典方式
先打印出每层的参数,查询出需要查询特定层的特定参数的名字
再使用model.state_dict()得到对应的参数

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 2))
for name, param in net.named_parameters():          #打印出整个网络的所有参数名称
    print(name, param.shape) 
print('-'*50)
print('0.weight: ', net.state_dict()['0.weight'].shape)          #打印特定的层和参数
参数初始化方法
def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)

根据需要选择对应的初始化方法
调用的时候仅需init_weights(net)即可。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/664725.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号