本篇自学笔记来自于b站《PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】》,Up主讲的非常通俗易懂,文章下方有视频连接,如有需要可移步up主讲解视频,如有侵权,实非故意,深表歉意,请与我联系,删除相关内容!
本节将介绍的内容有:1.使用torchvison中定义好的模型 ,2.如何修改定义好的模型,3.保存模型的方式和对应的加载模型的方式。
1.使用定义好的模型(以VGG16为例)
给出官方解释:如图可以看到有两个字段,分别为pretrained和progress,pretrained表示是否使用预训练好的模型,该模型是在ImageNet上训练好的。progress则为是否显示下载进度条。
代码为:分别写了pretrained为true和false的两种情况。
vgg16_false= torchvision.models.vgg16(pretrained=False) vgg16_true = torchvision.models.vgg16(pretrained=True)
2.修改上述模型
首先输出VGG16的网络模型如下:
首先演示给模型的classifier中添加一层Linear层:
vgg16_true.classifier.add_module(name="add_linear",module=nn.Linear(in_features=1000,out_features=10))
其中add_module的官方定义如下,需要给出添加的名字,以及具体添加的某一层的定义。
添加完之后结果如下: 其次也可以在模型中直接修改,接下来展示的是直接修改classifier中(6)Linear层。
vgg16_false.classifier[6] = nn.Linear(in_features=4096,out_features=10)
结果如下:
3.网络模型的保存和加载
保存方式主要有两种,对应的加载方式也是两种,下面展示第一种保存方式,这种保存方式保存了模型的结构和参数。第一个参数为要保存的模型名,第二个参数为要保存的路径。
torch.save(vgg16,"vgg16_method1.pth")
对应的读取方式为:
vgg_method1 = torch.load("vgg16_method1.pth")
读取的结果如下图,可以看到确实保存了模型的结构和参数
第二种保存方式,这种保存方式是以字典的形式保存了模型参数,而不保存模型结构,所以对应的读取方式也会不同。
#保存方式2
#只保存模型参数
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
#对应的读取
vgg16_method2 = vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16_method2)
附视频地址:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili



