栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch——保存,加载模型

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch——保存,加载模型

一文梳理pytorch保存和重载模型参数攻略 查看当前模型结构与参数值
print(model.state_dict)
# 输出定义的网络结构
print(model.state_dict())
# 输出所有参数名和参数值

输出如下

 bound method Module.state_dict of Digit(
 (conv1): Conv1d(2, 10, kernel_size (5,), stride (1,))
 (conv3): Conv1d(5, 20, kernel_size (3,), stride (1,))
 (fc6): Linear(in_features 2480, out_features 500, bias True)
 (drop8): Dropout(p 0.5, inplace False)
 (fc9): Linear(in_features 500, out_features 1, bias True)
OrderedDict([( conv1.weight , tensor([[[-0.2759, 0.1526, 0.2299, -0.2617, -0.0128],
 [ 0.2975, -0.1635, -0.1661, 0.1830, 0.1413]],
 [[ 0.0064, -0.1616, -0.2967, -0.3151, 0.0642],
 [-0.0369, 0.0338, 0.2795, 0.0888, -0.2408]],
 [[ 0.2387, -0.1673, -0.2089, 0.2312, -0.2677],
 [ 0.1646, -0.0508, -0.0151, 0.3200, -0.0355]],
 [[-0.2255, 0.0793, -0.2272, -0.0198, -0.2901],
 [-0.2260, 0.0601, -0.0991, 0.0732, -0.0444]],
保存模型
torch.save(obj model.state_dict(), f ./net.pth )
# 存储路径 上级目录同为 desktop
对新的网络加载参数值

首先定义一个新的空白参数值网络

class Model(nn.Module):
 def __init__(self):
 super(Model, self).__init__()
 self.layer nn.Linear(1, 1)
 self.layer.weight nn.Parameter(torch.FloatTensor([[0]]))
 self.layer.bias nn.Parameter(torch.FloatTensor([0]))
 def forward(self, x):
 out self.layer(x)
 return out
# 该网络只有一个线性层
modeldemo Model()
print(modeldemo.state_dict())
print(modeldemo.state_dict)
# 由于未经过训练 此时的权重和偏执都为0,
OrderedDict([( layer.weight , tensor([[0.]])), ( layer.bias , tensor([0.]))])
 bound method Module.state_dict of Model(
 (layer): Linear(in_features 1, out_features 1, bias True)

可以看出未经过训练 w , b 都为0

print(model.state_dict())
# modeldemo加载 另一个网络model的权重 注意此时的两个网络应该是一样的结构
保存该模型的其他数值

创建一个字典 然后保存这个字典 字典的字段是需要的数值即可

net Digit()
Adam optim.Adam(params net.parameters(), lr 0.001, betas (0.5,0.999))
Epo 97
all_states { net :net.state_dict(), Adam : Adam.state_dict(), epoch : Epo}
torch.save(obj all_states,f ./all_states.pth )

查看已保存的内容

{ net : OrderedDict([( conv1.weight , tensor([[[-0.0992, 0.1028, 0.1915, 0.2423, -0.3130],
 [ 0.0308, -0.0206, 0.0133, -0.2522, -0.2496]],
 [[-0.2329, -0.1573, 0.3153, 0.1176, 0.0190],
 [-0.2168, 0.1106, 0.1726, 0.0559, 0.2262]],
 [[ 0.3109, -0.3043, -0.2859, -0.1401, -0.0489],
 [-0.0905, -0.0871, -0.0425, -0.1573, -0.2254]],
 [[-0.1303, -0.0006, 0.2278, -0.0243, 0.2638],
 [-0.0177, -0.0474, -0.1561, 0.2652, -0.3036]],
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/267121.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号