- 预训练后的权重如何导入另一个网络模型?预训练对应的网络模型A与未训练的网络模型结构B不对应?
2.1 两个网络模型A和B只有部分对应
2.2 集合关系上A属于B
2.3 集合关系上B属于A
方案 PyTorch文档
torch.nn.modules.module.Module def load_state_dict(self,
state_dict: Dict[str, Tensor] | OrderedDict[str, Tensor],
strict: bool = …) -> None说明:将 state_dict 中的参数和缓冲区复制到此模块及其后代中。
如果 strict 为 True,则 state_dict 的键必须与此模块的torch.nn.Module.state_dict 函数返回的键完全匹配 参数
state_dict – 包含参数和持久缓冲区的字典。
strict – 是否严格强制:
attr:state_dict 中的键与该模块的 :meth:~torch.nn.Module.state_dict 函数返回的键匹配。 默认值:“真” 返回值:
missing_keys 是包含缺失键的 str 列表unexpected_keys 是包含意外键的 str 列表 模型对应,完全导入
# demo1 完全加载权重 model = NET1() state_dict = model.state_dict() weights = torch.load(weights_path)['model_state_dict'] #读取预训练模型权重 model.load_state_dict(weights)模型不完全对应
此一种情况经常出现在要修改预训练网络模型中某些层时,可能增加若干层,可能减少若干层,或上述两种情况皆有。
只有部分对应
两个模型中有部分是对应的,此种情况建议使用PyTorch中的load_state_dict所提供的参数:strict
将strict设置为False,可以在两个模型不同的情况下,仅加载相同键值部分。(保证各层的名字相同)
# demo2 model = NET2() state_dict = model.state_dict() weights = torch.load(weights_path)['model_state_dict'] #读取预训练模型权重 model.load_state_dict(weights, strict=False) #strictA属于B
此种情况常见于,在网上download别人的预训练模型后,需要根据自己的任务,添加若干个层,而其他层保持不变。
# demo3 *****待测试B属于A
此种情况常见于从网上download别人的预训练模型后,因为某些限制,需要对模型进行精简,只删除若干个层,其他层保持不变。
# demo4 *****待测试



