栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch 之 保存不同形式的预训练模型

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch 之 保存不同形式的预训练模型

注意,后缀.pt和.pth似乎没什么区别

保存时即可以保存整个模型也可以只保存参数,还可以构建新字典重新保存,这也就对应了在读取时需要做不同的处理,我们在加载的时候load_state_dict函数的参数就是OrderedDict类型的参数,这里给出了四种不同保存方式及其读取获得OrderedDict的方式。
1.保存

# coding=gbk
import torch 
import torch.nn as nn
class MLP_(nn.Module):
    def __init__(self):
        super(MLP_, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

net = MLP_()

#保存整个模型

torch.save(net, 'a1.pt')

all_model = {'model':net} #为模型部分添加键值,这样如果想要保存优化器参数的,可以向字典中加入新值
torch.save(all_model, 'a2.pt')
#只保存参数

torch.save(net.state_dict(),'a3.pt')
all_states = {'state_dict': net.state_dict()} #为模型参数部分添加键值,这样如果想要保存优化器参数的,可以向字典中加入新值
torch.save(all_states, 'a4.pt')

2.加载

# coding=gbk
import torch
from save import MLP_
if __name__ == "__main__":
    with torch.no_grad():
   
        a1 = 'a1.pt'
        a2 = 'a2.pt'
        a3 = 'a3.pt'
        a4 = 'a4.pt'
        a1_ = torch.load(a1)
        print(a1_.state_dict())
        a2_ = torch.load(a2)['model'] #通过键值选取对应值
        print(a2_.state_dict())
        a3_ = torch.load(a3)
        print(a3_)
        a4_ = torch.load(a4)['state_dict'] #通过键值选取对应值
        print(a4_)

参考:https://zhuanlan.zhihu.com/p/94971100

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/329388.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号