栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch加载模型错误 RuntimeError: Error(s) in loading state

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch加载模型错误 RuntimeError: Error(s) in loading state

模型在保存时侯以键对值保存,同时在加载时根据现在网络的键值查找模型对应的键值,然后加载。一般报错是因为模型和网络的键值不匹配。

1、最常见的问题是键值多了或者少了 module.

此种情况是模型在DataParallel或者DDP训练后保存的键值有module. ,对应的网络的键值则没有module.

1)可以通过:

model = nn.DataParallel(model)

将模型的键值加上module.

2) 也可以通过遍历模型的键对值修改键值。

   如:加载模型时删除多余的module.  代码如下

state_dict = torch.load(load_path)
for key, param in state_dict.items():
    if key.startswith('module.'):        #键值包含‘module.’ 则删除 
        state_dict[key[7:]] = param          
        state_dict.pop(key)
net.load_state_dict(state_dict)
        
2、详解load_state_dict(state_dict, False)的False参数

很多教程说名字不匹配直接添加False参数即可,但是这里需要注意一个大坑。

如果模型的键值和网络的键值完全不匹配,那么模型就没有加载预训练参数,虽然不再报错。

该False参数作用在于 非严格匹配加载模型,可以下面几种情况进行分析。

1)模型包含网络的部分参数

比如说模型是resnet101模型,你现在的网络是resnet50。再假设resnet50的参数名包含在resnet101的参数中,那么直接使用False会为你的网络resnet50加载键值相同的参数。这样就避免了对resnet101的每个键对值进行循环匹配,看是否是resnet50需要的。

2)模型完全不包含网络的参数

情况如1,模型有100个参数,都包含'module.' ,网络也有100个参数,都没有'module.' 。这种情况下如果参数设置为False,会发现没有任何键值能匹配上,因此网络就不会加载任何参数。

3)再介绍一个False使用场景

比如蒸馏网络PISR中,教师网络包含Encoder和Decoder两部分,学生网络由其中的Decoder部分组成,所以在训练学生网络时,如果要加载教师网络保存的预训练模型,设置False会自动识别Decoder部分键值相同,然后加载。

综上,设置False参数后依旧是按照键值查询加载参数的,有多少键值匹配,就加载多少模型的参数。

 

3、只要参数尺寸相同,就能加载

比如说我有一个10层网络的模型,还有一个3层的网络。我想把其中第9层的参数加载到现在网络的1层。如果参数的尺寸相同,就可以遍历键对值。将参数加载到想要的键值中。

state_dict = torch.load(load_path)
new_state_dict = []
for key, param in state_dict.items():
    if 'conv9' in key:        # 如果找到conv9对应的参数,将其键值替换为网络的键
        new_state_dict[key.replace('conv9', 'conv1')] = param   
net.load_state_dict(new_state_dict)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/324032.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号