栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Python环境下将ONNX模型转为fp16 半精度浮点方式

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Python环境下将ONNX模型转为fp16 半精度浮点方式

背景

在TX2上和NX上跑自己想要的模型还是有点慢,由于Jetpack4.6.2的TensorRT8.2对于有16G内存的NX支持存在问题运行不了(8G内存没有问题),可以运行的TensorRT7不支持我这边模型用到的einsum操作,所以我先想着改成fp16运行下看看

参考

https://blog.csdn.net/znsoft/article/details/114538684

流程
  1. 参考代码其实挺简单,但是python环境安装过程有点坎坷,建议新建一个虚拟环境来安装,好像有人把环境都直接装崩了
  2. 新建python3.7的虚拟环境,我是新建了基于python3.7的conda环境,注意哈,截至目前20220513来说这个winmltools无法在python3.8安装,build wheel会报错卡住,所以我最后安装的3.7的python,顺带吐槽下这破东西怎么要装这么多个版本的scipy还是啥的,就离谱
  3. 直接命令行安装:
pip install winmltools
  1. 安装好之后大概就可以按照下面代码把模型修改了:
from winmltools.utils import convert_float_to_float16
from winmltools.utils import load_model, save_model
onnx_model = load_model('model.onnx')
new_onnx_model = convert_float_to_float16(onnx_model)
save_model(new_onnx_model, 'model_fp16.onnx')
报错

我这边这个模型碰到了小问题,报错:

(op_type:AveragePool, name:AveragePool_141): Inferred shape and existing shape differ in dimension 2: (8) vs (7)
Traceback (most recent call last):
  File "G:/jupyter/fp16_convert/fp16_convert.py", line 4, in 
    new_onnx_model = convert_float_to_float16(onnx_model)
  File "D:ProgramDataAnacondaenvsfp16_convertlibsite-packagesonnxconverter_commonfloat16.py", line 139, in convert_float_to_float16
    model = func_infer_shape(model)
  File "D:ProgramDataAnacondaenvsfp16_convertlibsite-packagesonnxshape_inference.py", line 36, in infer_shapes
    inferred_model_str = C.infer_shapes(model_str)
RuntimeError: Inferred shape and existing shape differ in dimension 2: (8) vs (7)

Process finished with exit code 1

由于我是验证过的,可能是其他模型转onnx遇到了点小bug,把它infer那一段跳过就好了。根据报错内容跳转到shape_inference.py中,作如下修改:

def infer_shapes(model):  # type: (ModelProto) -> ModelProto
    if not isinstance(model, ModelProto):
        raise ValueError('Shape inference only accepts ModelProto, '
                         'incorrect type: {}'.format(type(model)))
    model_str = model.SerializeToString()
    return onnx.load_from_string(model_str)
    inferred_model_str = C.infer_shapes(model_str)
    return onnx.load_from_string(inferred_model_str)

重新运行代码,生成成功,放到NX开发板上跑,比float的快了大概1.5倍的样子。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/879273.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号