import torchvision.models as models
resnet34 = models.resnet34()
resnet34.load_state_dict(torch.load('latest.pth')['model'])
要解决的疑问
- load_state_dict torch.load作用
网络结构有了 这部分是在加载参数 - dummy input作用
给网络一个输入 - 如果dynamic_axes 后面输入可以更改指定的维度
- binding inputname outputname作用
binding 每个engine有且只有两个binding,对应输入输出
name可以理解为指针,在转onnx时候就指定根据这个指针拿到输入输出的内容
dummy_input=torch.randn(BATCH_SIZE, 3, 224, 224) import torch.onnx torch.onnx.export(resnet34, dummy_input, "rp_rec.onnx", verbose=False)注意
torchvision和mmcls的Resnet模型不一样
resnet34 = models.resnet34()
resnet34.load_state_dict(torch.load('latest.pth')['model'])
模型必须和参数对应起来
不能用torchvision的模型加载mmcls的参数
采用mmclassification框架,根据网络推理时的输入指定网络输入dummy_input,看推理代码,如果网络允许某个维度有变化,那么可以设定dynamic_axes(某个维度定死了,就不要dynamic_axes),采用verify参数,对比模型的输出是否一致



