栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

网络模型的保存与读取

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

网络模型的保存与读取

网络模型的保存于与读取 方法1: 1.1 如何保存网络模型

首先,创建一个py文件,model_save.py

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)
torch.save(vgg16,"vgg16_model1_pth")

运行结束后我们会在我们左侧的文件出现vgg16_model1_pth这个文件
用这种方法保存,不仅保存了网络模型,也保存了网络模型中的相关参数

1.2 如何读取网络模型

新建一个py文件,model_load.py

import torch

model = torch.load("vgg16_model1_pth")
print(model)

输出:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
方法二 2.1:如何保存网络模型
import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

#torch.save(vgg16,"vgg16_model1_pth")
torch.save(vgg16.state_dict(),"vgg16_model2_pth")

也会在左侧形成一个vgg16_model2_pth文件
只保存了模型的参数,占用空间更小,官方推荐方式

2.2:如何读取网络模型

读取方式与方法一 一样,但是输出为字典类型的数据

import torch

# 方式2-> 保存方式2,加载模型
model = torch.load("vgg16_model2_pth")  # 加载出来的是字典类型的数据
print(model)

F:Anaconda3envspytorchpython.exe D:/Python/learn_torch/model_load.py
OrderedDict([('features.0.weight', tensor([[[[ 3.9726e-02, -4.0263e-02,  5.2152e-02],
          [ 3.5984e-02, -4.6239e-02, -2.4924e-02],
          [-9.6867e-03,  1.2961e-02, -4.5731e-02]],

         [[ 1.9925e-03,  3.6464e-02,  5.6411e-02],
          [-9.0956e-02, -3.6801e-02, -7.3917e-02],
          [ 3.6363e-02, -4.5585e-02, -8.2003e-03]],

         [[-1.1151e-01, -2.4694e-02, -3.4446e-02],
          [-5.4018e-02,  7.9030e-02,  1.1468e-01],
          [ 6.1839e-02, -8.7451e-02,  2.8596e-03]]],


        [[[-6.4775e-02,  5.2936e-03, -1.8106e-02],
          [-4.0254e-02, -8.5685e-02, -7.8011e-02],
          [ 1.1739e-02, -7.9629e-02,  6.6174e-02]],

         [[-1.1657e-01,  3.5422e-02,  6.2663e-02],
          [ 3.0534e-02,  6.9120e-03,  3.3340e-03],
          [-1.5356e-01,  7.2058e-02,  4.7606e-02]],

         [[-1.2942e-01, -3.5475e-02,  9.7374e-02],
          [-1.3898e-02, -2.5312e-02,  6.3060e-02],
          [ 5.4231e-04,  1.4181e-02,  8.3530e-02]]],


        [[[-1.5726e-03,  6.0129e-02, -2.5256e-02],
          [-8.2932e-02,  9.2577e-02,  1.8457e-02],
          [-5.7204e-02, -5.2296e-02,  8.6386e-02]],

         [[-3.1392e-02,  1.2295e-01, -6.2096e-03],
          [-1.6034e-02,  3.0497e-03,  5.9402e-02],
          [-7.5480e-02, -6.9659e-02, -1.2263e-02]],

         [[ 6.5706e-05, -4.6442e-02,  6.1466e-02],
          [ 3.6150e-02,  3.6947e-02, -9.4802e-02],
          [ 7.0997e-02,  1.2181e-02,  3.3660e-03]]],
          .....................................
          ....................................
          ..................................

从上述输出结果中得到的结果是字典类型,其中参数的值也一起输出来了,如果想要查看具体的网络结构,需要这样

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_model2_pth"))  # 输出完整的模型结构,与第一种方式输出的模型结构相同
print(vgg16)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/738587.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号