栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

PyTorch | 模型的保存和加载

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

PyTorch | 模型的保存和加载

PyTorch | 模型的保存和加载
  • 一、模型参数的保存和加载
  • 二、完整模型的保存和加载

一、模型参数的保存和加载
  • torch.save(module.state_dict(), path):使用module.state_dict()函数获取各层已经训练好的参数和缓冲区,然后将参数和缓冲区保存到path所指定的文件存放路径(常用文件格式为.pt、.pth或.pkl)。
  • torch.nn.Module.load_state_dict(state_dict):从state_dict中加载参数和缓冲区到Module及其子类中 。
  • torch.nn.Module.state_dict()函数返回python中的一个OrderedDict类型字典对象,该对象将每一层与它的对应参数和缓冲区建立映射关系,字典的键值是参数或缓冲区的名称。只有那些参数可以训练的层才会被保存到OrderedDict中,例如:卷积层、线性层等。
  • Python中的字典类以“键:值”方式存取数据,OrderedDict是它的一个子类,实现了对字典对象中元素的排序(OrderedDict根据放入元素的先后顺序进行排序)。由于进行了排序,所以顺序不同的两个OrderedDict字典对象会被当做是两个不同的对象。
  • 示例:
    import torch
    import torch.nn as nn
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x
    
    # 初始化网络
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()
    # 获取state_dict
    state_dict = net.state_dict()
    # 字典的遍历默认是遍历key,所以param_tensor实际上是键值
    for param_tensor in state_dict: 
        print(param_tensor,':n',state_dict[param_tensor])
    # 保存模型参数
    torch.save(state_dict,"net_params.pth")
    # 通过加载state_dict获取模型参数
    net.load_state_dict(state_dict)
    
    输出:
二、完整模型的保存和加载
  • torch.save(module, path):将训练完的整个网络模型module保存到path所指定的文件存放路径(常用文件格式为.pt或.pth)。
  • torch.load(path):加载保存到path中的整个神经网络模型。
  • 示例:
    import torch
    import torch.nn as nn
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x
    
    # 初始化网络
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()
    # 保存整个网络
    torch.save(net,"net.pth")
    # 加载网络
    net = torch.load("net.pth")
    
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/339929.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号