device = torch.device('cuda' if torch.cuda.is_available() else 'cpu' )
filepath = 'model.dat'
# 保存参数
torch.save(model.state_dict(), filepath)
# 加载模型参数 , map_location: 把数据加载到哪个device(GPU或CPU)
model.load_state_dict(torch.load(filepath, map_location=device)
方法二:保存模型参数的同时保存其他训练相关状态,以便再次加载先前状态进行训练
# 保存相关状态
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
# 加载
state = torch.load(filepath)
epoch = state['epoch']
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
...
方法三: 保存整个模型(一般不建议使用)
# 保存 torch.save(model, filepath) # 加载 model = torch.load(filepath)
参考链接: https://stackoverflow.com/a/49078976.



