栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch保存和加载模型_pytorch载入模型?

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch保存和加载模型_pytorch载入模型?

可能会遇到有一个多gpu的训练后保存的模型,但是后续要在单gpu的机子上使用,在torxh.load时报错。
这是因为nn.DataParallel会在模型参数结构前面加一个module.
比如你是这么save的

state = {'epoch': epoch, 'state_dict': self.model.state_dict(),
                     'optimizer': self.optimizer.state_dict(),
                     'info': self.info,
                     'scheduler': self.scheduler.state_dict()
                     }
 
w_dict = torch.load(path + "/pretrain_model_name", map_location=lambda storage, loc: storage)  # 权重是 多gpu的,去掉module字样
# print( "original dict:--", w_dict)

new_state_dict= OrderedDict()
for  k,v in w_dict['state_dict'].items():  # 字典 'state_dict':
        namekey= k[7:] if k.startswith('module.') else  k
        new_state_dict[namekey]= v
print("new_state_dict---: {}".format(new_state_dict) )
 model.load_state_dict(new_state_dict)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/783503.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号