栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

MMClassificatio 框架下 Pytorch模型转TensorRT

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

MMClassificatio 框架下 Pytorch模型转TensorRT

模型的加载
import torchvision.models as models
resnet34 = models.resnet34()
resnet34.load_state_dict(torch.load('latest.pth')['model'])
要解决的疑问
  • load_state_dict torch.load作用
    网络结构有了 这部分是在加载参数
  • dummy input作用
    给网络一个输入
  • 如果dynamic_axes 后面输入可以更改指定的维度
  • binding inputname outputname作用
    binding 每个engine有且只有两个binding,对应输入输出
    name可以理解为指针,在转onnx时候就指定根据这个指针拿到输入输出的内容
dummy_input=torch.randn(BATCH_SIZE, 3, 224, 224)
import torch.onnx
torch.onnx.export(resnet34, dummy_input, "rp_rec.onnx", verbose=False)
注意

torchvision和mmcls的Resnet模型不一样

resnet34 = models.resnet34()
resnet34.load_state_dict(torch.load('latest.pth')['model'])

模型必须和参数对应起来
不能用torchvision的模型加载mmcls的参数

Pytorch转TensorRT方法总结

采用mmclassification框架,根据网络推理时的输入指定网络输入dummy_input,看推理代码,如果网络允许某个维度有变化,那么可以设定dynamic_axes(某个维度定死了,就不要dynamic_axes),采用verify参数,对比模型的输出是否一致

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/499189.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号