通常在训练模型的过程中,可能会遭遇断电、断网的尴尬,一旦出现这种情况,先前训练的模型就白费了,又得重头开始训练。因此每隔一段时间就将训练模型信息保存一次很有必要。而这些信息不光包含模型的参数信息,还包含其他信息,如当前的迭代次数,优化器的参数等,以便用于后面恢复训练。
state = {
'epoch' : epoch + 1, #保存当前的迭代次数
'state_dict' : model.state_dict(), #保存模型参数
'optimizer' : optimizer.state_dict(), #保存优化器参数
..., #其余一些想保持的参数都可以添加进来
...,
}
torch.save(state, 'checkpoint.pth.tar') #将state中的信息保存到checkpoint.pth.tar
#Pytorch 约定使用.tar格式来保存这些检查点
#当想恢复训练时
checkpoint = torch.load('checkpoint.pth.tar')
epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['state_dict']) #加载模型的参数
optimizer.load_state_dict(checkpoint['optimizer']) #加载优化器的参数
参考资料:
pytorch如何保存模型?



