经常看到(.pt,.pth,.pkl)的pytorch模型文件,并不是格式上不同,只是后缀不同; torch.save函数保存模型文件时,因人而异; 重点在于保存模型的方式不同,需要注意1.1、只保存模型参数,不保存模型结构
保存:
# 模型权重参数,不保存模型结构,速度快,占空间少
torch.save(model.state_dict(), "mymodel.pth")
调用:
# 这里需要重新模型结构,My_model
model = My_model(*args, **kwargs)
# 这里根据模型结构,调用存储的模型参数
model.load_state_dict(torch.load(mymodel.pth))
model.eval()
1.2、保存整个模型,包括模型结构和模型参数
保存:
# 保存整个model的状态
torch.save(model, mymodel.pth)
调用:
# 这里已经不需要重构模型结构了,直接load就可以
model=torch.load(mymodel.pth)
model.eval()
1.3、保存更多信息,如优化器参数
1)保存信息至字典,获取时通过字典获取
保存:
torch.save({'epoch': epochID + 1,
'state_dict':model.state_dict(),
'best_loss': lossMIN,
'optimizer': optimizer.state_dict(),
'alpha': loss.alpha,
'gamma': loss.gamma},
checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f"% lossMIN) + '.pth.tar')
调用:
def load_checkpoint(model, checkpoint_PATH, optimizer):
if checkpoint != None:
model_CKPT = torch.load(checkpoint_PATH)
model.load_state_dict(model_CKPT['state_dict'])
print('loading checkpoint!')
optimizer.load_state_dict(model_CKPT['optimizer'])
return model, optimizer
2)如若修改了网络结构,如增删操作,则需要过滤这些参数,加载方式略有不同
def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
if checkpoint != 'No':
print("loading checkpoint...")
model_dict = model.state_dict()
modelCheckpoint = torch.load(checkpoint)
pretrained_dict = modelCheckpoint['state_dict']
# 过滤操作
new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
model_dict.update(new_dict)
# 打印出来,更新了多少的参数
print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
model.load_state_dict(model_dict)
print("loaded finished!")
# 如果不需要更新优化器那么设置为false
if loadOptimizer == True:
optimizer.load_state_dict(modelCheckpoint['optimizer'])
print('loaded! optimizer')
else:
print('not loaded optimizer')
else:
print('No checkpoint is included')
return model, optimizer
1.4、冻结部分参数,训练另一部分参数(special)
需求较少,后续添加



