栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

2021-10-19

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

2021-10-19

江大白共学计划课程–Pytorch课程作业(1)
import torch
import torchvision.models as models
from thop import profile
# 加载模型结构
model = models.resnet18()
# 读取权重和载入权重
pretrained_state_dict = torch.load('Lesson2/weights/resnet18-5c106cde.pth')
model_state_dict = model.load_state_dict(pretrained_state_dict, strict=False)


# 让模型变为推理状态
model.eval()
# 将模型放置到cuda上
model.to(torch.device('cuda'))

# 构建一个项目推理时需要的输入大小的单精度Tensor,并且放置模型所在的设备(cpu或cuda)
inputs = torch.ones([1, 3, 224, 224]).type(torch.float32).to(torch.device('cuda'))
flops, params = profile(model=model, inputs=(inputs))
print('Model:{:.2f} GFLOPs and {:.2f}M parameters'.format(flops/1e9, params/1e6))

# 生成onnx:将训练好的模型(包括结构和权重)保存成onnx。
torch.onnx.export(model, inputs, 'Lesson2/weights/resnet18.onnx', verbose = False)
输入:[1, 3, 224, 224]
输出:resnet18模型的参数量和计算量
Model:0.31 GFLOPs and 3.50M parameters

输入:[1, 3, 448, 448]
输出:resnet18模型的参数量和计算量
Model:1.25 GFLOPs and 3.50M parameters
Netron可视化ResNet18的网络结构

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/339912.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号