使用tensorrt进行加速的时候,遇到
需要自已定义torch.bmm进行替换。
- 进入到安装torch2trt所在的converters目录,例如我的在/opt/conda/lib/python3.7/site-packages/torch2trt-0.3.0-py3.7-linux-x86_64.egg/torch2trt/converters/
- 创建bmm.py文件并写入以下内容
from torch2trt.torch2trt import *
# converter added
@tensorrt_converter('torch.bmm')
@tensorrt_converter('torch.Tensor.bmm')
def convert_bmm(ctx):
input_a = ctx.method_args[0]
input_b = ctx.method_args[1]
output = ctx.method_return
input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b)
layer = ctx.network.add_matrix_multiply(input_a_trt,trt.MatrixOperation.NONE, input_b_trt,trt.MatrixOperation.NONE)
output._trt = layer.get_output(0)
- 打开与当前目录同级下的__init__.py文件,并添加以下内容即可。
from .bmm import convert_bmm



