当想进行迁移学习或者训练一个更复杂模型的时候可能会需要利用一个其他模型训练好的参数,两者的参数列表可能不同,要导入其实非常简单:
- 存储 modelA 的参数
torch.save(modelA.state_dict(), PATH)
- 加载到 modelB 上:
modelB = TheModelBClass(*args, **kwargs) modelB.load_state_dict(torch.load(PATH), strict=False)
load_state_dict 的strict参数设为False即可.
参考: https://pytorch.org/tutorials/beginner/saving_loading_models.html#warmstarting-model-using-parameters-from-a-different-model



