栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch加载部分预训练模型_pytorch学习007- -预训练中的权重加载(完全导入,部分导入)?

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch加载部分预训练模型_pytorch学习007- -预训练中的权重加载(完全导入,部分导入)?

问题
    预训练后的权重如何导入另一个网络模型?预训练对应的网络模型A与未训练的网络模型结构B不对应?
    2.1 两个网络模型A和B只有部分对应
    2.2 集合关系上A属于B
    2.3 集合关系上B属于A

方案 PyTorch文档

torch.nn.modules.module.Module def load_state_dict(self,
state_dict: Dict[str, Tensor] | OrderedDict[str, Tensor],
strict: bool = …) -> None说明:将 state_dict 中的参数和缓冲区复制到此模块及其后代中。

如果 strict 为 True,则 state_dict 的键必须与此模块的torch.nn.Module.state_dict 函数返回的键完全匹配 参数
state_dict – 包含参数和持久缓冲区的字典。
strict – 是否严格强制:

attr:state_dict 中的键与该模块的 :meth:~torch.nn.Module.state_dict 函数返回的键匹配。 默认值:“真” 返回值:

missing_keys 是包含缺失键的 str 列表unexpected_keys 是包含意外键的 str 列表 模型对应,完全导入

# demo1 完全加载权重
model = NET1()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict'] #读取预训练模型权重
model.load_state_dict(weights)
模型不完全对应

此一种情况经常出现在要修改预训练网络模型中某些层时,可能增加若干层,可能减少若干层,或上述两种情况皆有。

只有部分对应


两个模型中有部分是对应的,此种情况建议使用PyTorch中的load_state_dict所提供的参数:strict
将strict设置为False,可以在两个模型不同的情况下,仅加载相同键值部分。(保证各层的名字相同)

# demo2
model = NET2()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict']	#读取预训练模型权重
model.load_state_dict(weights, strict=False)	#strict

A属于B


此种情况常见于,在网上download别人的预训练模型后,需要根据自己的任务,添加若干个层,而其他层保持不变。

# demo3
*****待测试
B属于A


此种情况常见于从网上download别人的预训练模型后,因为某些限制,需要对模型进行精简,只删除若干个层,其他层保持不变。

# demo4
*****待测试
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/786890.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号