# Inspect the network: the bound method repr shows the module structure,
# calling it returns an OrderedDict of every parameter name and value.
# (Extraction had fused the second statement into the first one's comment.)
print(model.state_dict)    # prints the bound method / the defined network structure
print(model.state_dict())  # prints all parameter names and their values
输出如下
<bound method Module.state_dict of Digit(
  (conv1): Conv1d(2, 10, kernel_size=(5,), stride=(1,))
  (conv3): Conv1d(5, 20, kernel_size=(3,), stride=(1,))
  (fc6): Linear(in_features=2480, out_features=500, bias=True)
  (drop8): Dropout(p=0.5, inplace=False)
  (fc9): Linear(in_features=500, out_features=1, bias=True)
)>
OrderedDict([('conv1.weight', tensor([[[-0.2759, 0.1526, 0.2299, -0.2617, -0.0128], [ 0.2975, -0.1635, -0.1661, 0.1830, 0.1413]], [[ 0.0064, -0.1616, -0.2967, -0.3151, 0.0642], [-0.0369, 0.0338, 0.2795, 0.0888, -0.2408]], [[ 0.2387, -0.1673, -0.2089, 0.2312, -0.2677], [ 0.1646, -0.0508, -0.0151, 0.3200, -0.0355]], [[-0.2255, 0.0793, -0.2272, -0.0198, -0.2901], [-0.2260, 0.0601, -0.0991, 0.0732, -0.0444]], ...

保存模型
# Persist only the parameters (state_dict), not the whole model object.
# Path is relative to the current working directory (here: desktop).
torch.save(obj=model.state_dict(), f='./net.pth')
首先定义一个新的空白参数值网络
class Model(nn.Module):
    """A minimal demo network with a single 1-in/1-out linear layer.

    The weight and bias are explicitly zero-initialized so that the
    freshly-constructed (untrained) state_dict is easy to recognise
    before loading saved parameters into it.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.layer = nn.Linear(1, 1)
        # Overwrite the default random init with zeros for demonstration.
        self.layer.weight = nn.Parameter(torch.FloatTensor([[0]]))
        self.layer.bias = nn.Parameter(torch.FloatTensor([0]))

    def forward(self, x):
        out = self.layer(x)
        return out
# Instantiate the blank network and inspect it: since it has not been
# trained (and was zero-initialized), weight and bias are both 0.
modeldemo = Model()
print(modeldemo.state_dict())  # OrderedDict of parameter names and (zero) values
print(modeldemo.state_dict)    # bound method repr showing the network structure
OrderedDict([('layer.weight', tensor([[0.]])), ('layer.bias', tensor([0.]))])
<bound method Module.state_dict of Model( (layer): Linear(in_features=1, out_features=1, bias=True) )>
可以看出，未经过训练时 w、b 都为 0。
# NOTE(review): the surrounding text says modeldemo loads model's weights at
# this point, but the load call itself appears to have been lost in extraction —
# presumably: modeldemo.load_state_dict(torch.load('./net.pth')).
# Loading requires both networks to have identical architectures.
print(model.state_dict())
创建一个字典 然后保存这个字典 字典的字段是需要的数值即可
# Save extra training state alongside the network: build a dict whose keys
# are whatever values you need to restore (model, optimizer, epoch), then
# save the whole dict in one call.
net = Digit()
Adam = optim.Adam(params=net.parameters(), lr=0.001, betas=(0.5, 0.999))
Epo = 97  # example: the epoch reached when checkpointing
all_states = {'net': net.state_dict(), 'Adam': Adam.state_dict(), 'epoch': Epo}
torch.save(obj=all_states, f='./all_states.pth')
查看已保存的内容
{'net': OrderedDict([('conv1.weight', tensor([[[-0.0992, 0.1028, 0.1915, 0.2423, -0.3130],
[ 0.0308, -0.0206, 0.0133, -0.2522, -0.2496]],
[[-0.2329, -0.1573, 0.3153, 0.1176, 0.0190],
[-0.2168, 0.1106, 0.1726, 0.0559, 0.2262]],
[[ 0.3109, -0.3043, -0.2859, -0.1401, -0.0489],
[-0.0905, -0.0871, -0.0425, -0.1573, -0.2254]],
[[-0.1303, -0.0006, 0.2278, -0.0243, 0.2638],
[-0.0177, -0.0474, -0.1561, 0.2652, -0.3036]],


