栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

tensorrt遇到torch.bmm的解决

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

tensorrt遇到torch.bmm的解决

一、问题说明

使用tensorrt进行加速的时候,遇到

需要自已定义torch.bmm进行替换。

二、问题解决
  1. 进入到安装torch2trt所在的converters目录,例如我的在/opt/conda/lib/python3.7/site-packages/torch2trt-0.3.0-py3.7-linux-x86_64.egg/torch2trt/converters/
  2. 创建bmm.py文件并写入以下内容
from torch2trt.torch2trt import *
# converter added
@tensorrt_converter('torch.bmm')
@tensorrt_converter('torch.Tensor.bmm')
def convert_bmm(ctx):
    input_a = ctx.method_args[0]
    input_b = ctx.method_args[1]
    output = ctx.method_return
    input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b)
    layer = ctx.network.add_matrix_multiply(input_a_trt,trt.MatrixOperation.NONE, input_b_trt,trt.MatrixOperation.NONE)
    output._trt = layer.get_output(0)
  1. 打开与当前目录同级下的__init__.py文件,并添加以下内容即可。
from .bmm import convert_bmm
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/275173.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号