栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch学习笔记(七)---网络模型的使用修改、保存以及加载

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch学习笔记(七)---网络模型的使用修改、保存以及加载

        本篇自学笔记来自于b站《PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】》,Up主讲的非常通俗易懂,文章下方有视频连接,如有需要可移步up主讲解视频,如有侵权,实非故意,深表歉意,请与我联系,删除相关内容!

        本节将介绍的内容有:1.使用torchvison中定义好的模型 ,2.如何修改定义好的模型,3.保存模型的方式和对应的加载模型的方式。

        1.使用定义好的模型(以VGG16为例)

         给出官方解释:如图可以看到有两个字段,分别为pretrained和progress,pretrained表示是否使用预训练好的模型,该模型是在ImageNet上训练好的。progress则为是否显示下载进度条。

        代码为:分别写了pretrained为true和false的两种情况。

vgg16_false= torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

        2.修改上述模型 

        首先输出VGG16的网络模型如下:

 

        首先演示给模型的classifier中添加一层Linear层:

vgg16_true.classifier.add_module(name="add_linear",module=nn.Linear(in_features=1000,out_features=10))

         其中add_module的官方定义如下,需要给出添加的名字,以及具体添加的某一层的定义。    

        添加完之后结果如下:         其次也可以在模型中直接修改,接下来展示的是直接修改classifier中(6)Linear层。

vgg16_false.classifier[6] = nn.Linear(in_features=4096,out_features=10)

        结果如下: 

        3.网络模型的保存和加载 

        保存方式主要有两种,对应的加载方式也是两种,下面展示第一种保存方式,这种保存方式保存了模型的结构和参数。第一个参数为要保存的模型名,第二个参数为要保存的路径。 

torch.save(vgg16,"vgg16_method1.pth")

         对应的读取方式为:

vgg_method1 = torch.load("vgg16_method1.pth")

        读取的结果如下图,可以看到确实保存了模型的结构和参数

         第二种保存方式,这种保存方式是以字典的形式保存了模型参数,而不保存模型结构,所以对应的读取方式也会不同。

#保存方式2
#只保存模型参数
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
#对应的读取
vgg16_method2 = vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16_method2)

附视频地址:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/618922.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号