/opt/tools/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1221
1222 if len(error_msgs) > 0:
-> 1223 raise RuntimeError('Error(s) in loading state_dict for {}:nt{}'.format(
1224 self.__class__.__name__, "nt".join(error_msgs)))
1225 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for UNet:
Missing key(s) in state_dict:................(太多了)
Unexpected key(s) in state_dict: ................(也很多)
2. 解决方案
train的时候我加了数据并行
# train model = nn.DataParallel(model)
但是test里没加并行,加上就不会报错了
# test model = UNet(n_channels=3, n_classes=1) model = nn.DataParallel(model) model.load_state_dict(torch.load(weight_path))
问题解决了之后我还想探索一下nn.DataParallel()到底做了些什么,与不并行有什么区别。
3. 错误分析以上的问题大概是导入的权重的键与模型本身的键无法对应。比如相同的地方,导入的权重叫module.inc.conv.conv.0.weight, 而模型这个地方的权重叫inc.conv.conv.0.weight。
并且两边权重的键个数也不一样,加了并行的模型权重多了18个结尾是num_batches_tracked的键,可能是用来记录这个分支算了多少个batch(猜想)
所以给test加上数据并行的最省事的解决方案了。



