栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

load

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

load

1. 报错内容
/opt/tools/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1221 
   1222         if len(error_msgs) > 0:
-> 1223             raise RuntimeError('Error(s) in loading state_dict for {}:nt{}'.format(
   1224                                self.__class__.__name__, "nt".join(error_msgs)))
   1225         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for UNet:
	Missing key(s) in state_dict:................(太多了)
	Unexpected key(s) in state_dict: ................(也很多)

2. 解决方案

train的时候我加了数据并行

# train
model = nn.DataParallel(model)

但是test里没加并行,加上就不会报错了

# test
model = UNet(n_channels=3, n_classes=1)
model = nn.DataParallel(model)
model.load_state_dict(torch.load(weight_path))

问题解决了之后我还想探索一下nn.DataParallel()到底做了些什么,与不并行有什么区别。

3. 错误分析

以上的问题大概是导入的权重的键与模型本身的键无法对应。比如相同的地方,导入的权重叫module.inc.conv.conv.0.weight, 而模型这个地方的权重叫inc.conv.conv.0.weight

并且两边权重的键个数也不一样,加了并行的模型权重多了18个结尾是num_batches_tracked的键,可能是用来记录这个分支算了多少个batch(猜想)

所以给test加上数据并行的最省事的解决方案了。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/766888.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号