栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

将多个GPU上用pytorch框架并行训练的神经网络模型应用到CPU上

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

将多个GPU上用pytorch框架并行训练的神经网络模型应用到CPU上

本人用pytorch框架在两块GPU上并行训练了一个神经网络模型,并将训练的不同阶段的结果保存起来,以便用于模型集成。

虽然模型是在GPU上训练的,但是在服务器上部署的时候只需用CPU就可以进行模型推断。但在实际应用中,却出现如下报错信息:

RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu

解决的思路是在服务器上先将模型加载进来,然后用一种新的方式重新保存。见下面的代码段:

    model_list = ['model_1.tar','model_2.tar','model_3.tar','model_4.tar','model_5.tar']
    model_path = './sel_models/'

    # set model
    device = torch.device('cpu')
       
    for model_name in model_list:
        model = Net(num_classes=num_classes, num_channels=num_channels).to(device, dtype=torch.float)
        model = nn.DataParallel(model)
        
        # load trained model
        checkpoint = torch.load(os.path.join(model_path, model_name), map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        del checkpoint
        model = model.to(device, dtype=torch.float)
        
        cpu_model_path = './model_for_cpu/'
        if not os.path.exists(cpu_model_path):
            os.mkdir(cpu_model_path)
        
        torch.save({'model_state_dict': model.module.state_dict()},os.path.join(cpu_model_path, model_name))

这里需要注意的是 必须要加上 model = nn.DataParallel(model), 因为模型是在双GPU上并行训练的,不加这句话模型加载就会出错。另外重新保存的时候一定要加上 'module', 即 model.module.state_dict(), 而不是model.state_dict(),这也是解决这个错误的关键。

本文参考:

(26条消息) pytorch加载多GPU模型和单GPU模型(遗漏module的解决)_律己且好学,才能保证不坠入愤世嫉俗之列。-CSDN博客https://blog.csdn.net/qq_18649781/article/details/90270323?spm=1001.2101.3001.6650.1&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Edefault-1.no_search_link&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Edefault-1.no_search_link

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/348785.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号