栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch - 模型保存及加载

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch - 模型保存及加载

一、pytorch 模型保存及加载
经常看到(.pt,.pth,.pkl)的pytorch模型文件,并不是格式上不同,只是后缀不同;
torch.save函数保存模型文件时,因人而异;
重点在于保存模型的方式不同,需要注意
1.1、只保存模型参数,不保存模型结构
保存:
    # 模型权重参数,不保存模型结构,速度快,占空间少
    torch.save(model.state_dict(), "mymodel.pth")    
调用:
    # 这里需要重新模型结构,My_model
    model = My_model(*args, **kwargs)     
    # 这里根据模型结构,调用存储的模型参数
    model.load_state_dict(torch.load(mymodel.pth))  
    model.eval()
1.2、保存整个模型,包括模型结构和模型参数
保存:
    # 保存整个model的状态
    torch.save(model, mymodel.pth)
调用:
    # 这里已经不需要重构模型结构了,直接load就可以
    model=torch.load(mymodel.pth)
    model.eval()
1.3、保存更多信息,如优化器参数

1)保存信息至字典,获取时通过字典获取

保存:
    torch.save({'epoch': epochID + 1,
                'state_dict':model.state_dict(),
                'best_loss': lossMIN,
                'optimizer': optimizer.state_dict(),
                'alpha': loss.alpha,
                'gamma': loss.gamma},
                checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f"% lossMIN) + '.pth.tar')
调用:
    def load_checkpoint(model, checkpoint_PATH, optimizer):
        if checkpoint != None:
            model_CKPT = torch.load(checkpoint_PATH)
            model.load_state_dict(model_CKPT['state_dict'])
            print('loading checkpoint!')
            optimizer.load_state_dict(model_CKPT['optimizer'])
        return model, optimizer 

2)如若修改了网络结构,如增删操作,则需要过滤这些参数,加载方式略有不同

def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
    if checkpoint != 'No':
        print("loading checkpoint...")
        model_dict = model.state_dict()
        modelCheckpoint = torch.load(checkpoint)
        pretrained_dict = modelCheckpoint['state_dict']
        # 过滤操作
        new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
        model_dict.update(new_dict)
        # 打印出来,更新了多少的参数
        print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
        model.load_state_dict(model_dict)
        print("loaded finished!")
        # 如果不需要更新优化器那么设置为false
        if loadOptimizer == True:
            optimizer.load_state_dict(modelCheckpoint['optimizer'])
            print('loaded! optimizer')
        else:
            print('not loaded optimizer')
    else:
        print('No checkpoint is included')
    return model, optimizer
1.4、冻结部分参数,训练另一部分参数(special)

需求较少,后续添加

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/269292.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号