栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch保存checkpoint(检查点):通常在训练模型的过程中,每隔一段时间就将训练模型信息保存一次【包含模型的参数信息,还包含其他信息,如当前的迭代次数,优化器的参数等,以便用于后面恢复】

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch保存checkpoint(检查点):通常在训练模型的过程中,每隔一段时间就将训练模型信息保存一次【包含模型的参数信息,还包含其他信息,如当前的迭代次数,优化器的参数等,以便用于后面恢复】

通常在训练模型的过程中,可能会遭遇断电、断网的尴尬,一旦出现这种情况,先前训练的模型就白费了,又得重头开始训练。因此每隔一段时间就将训练模型信息保存一次很有必要。而这些信息不光包含模型的参数信息,还包含其他信息,如当前的迭代次数,优化器的参数等,以便用于后面恢复训练。

state = {
    'epoch' : epoch + 1,  #保存当前的迭代次数
    'state_dict' : model.state_dict(), #保存模型参数
    'optimizer' : optimizer.state_dict(), #保存优化器参数
    ...,      #其余一些想保持的参数都可以添加进来
    ...,
}

torch.save(state, 'checkpoint.pth.tar')     #将state中的信息保存到checkpoint.pth.tar
#Pytorch 约定使用.tar格式来保存这些检查点


#当想恢复训练时
checkpoint = torch.load('checkpoint.pth.tar')
epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['state_dict'])   #加载模型的参数
optimizer.load_state_dict(checkpoint['optimizer']) #加载优化器的参数



参考资料:
pytorch如何保存模型?

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/871047.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号