- 一、pytorch中现有网络模型的使用、修改
- 二、模型的保存和加载
- 1. 模型的保存
- 2.模型的加载
-
位于torchvision.models
-
使用vgg模型为例,采用的数据集是ImageNet,而ImageNet数据集使用前提需要有scipy包
pip install scipy注意:ImageNet光训练集就有147.9G,而且不再能公开访问了
-
pytorch中使用现有网络模型以及修改现有的网络模型代码示例
import torchvision
# train_data = torchvision.datasets.ImageNet("../data_image_net", split="train", download=True,
# transform=torchvision.transforms.ToTensor())
from torch import nn
"""
理解:
1. pretrained=False时,相当于使用pytorch中现有的网络模型,其中各层的参数采用默认的
2. pretrained=True时,相当于使用pytorch中现有的网络模型,但其中各层的参数采用 我们在数据集上训练好的参数
"""
# 1.使用现有的网络模型
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 2.在现有的网络模型中添加一层
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
# 3.修改现有网络中的某层的参数
vgg16_false.classifier[7] = nn.Linear(4096, 10)
二、模型的保存和加载
1. 模型的保存
import torch
import torchvision
from torch import nn
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,保存了网络模型的结构以及其中的参数
torch.save(vgg16, "vgg16_method1.pth")
# 保存方式2,把网络模型的参数保存成字典,不再保存网络模型的结构(官方推荐)占的空间小
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
# 陷阱,用方式1保存自己写的神经网络
class MyNeural(nn.Module):
def __init__(self):
super(MyNeural, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
my_neural = MyNeural()
torch.save(my_neural, "my_neural_method1.pth")
2.模型的加载
import torch
import torchvision
from c17_model_save import *
vgg16 = torchvision.models.vgg16(pretrained=False)
# 加载方式1,对应保存方式1
model = torch.load("vgg16_method1.pth")
print(model)
# 加载方式2,对应保存方式2
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)
# 陷阱1,
# 要让该.py文件加载自己定义的神经网络,需要引入自己定义的神经网络的模板类 from c17_model_save import *
model = torch.load("my_neural_method1.pth")
print(model)



