In GraphProto.from_onnx, the importer iterates over the nodes of the ONNX model and converts each node into TVM IR:
for node in graph.node:
    # Get the operator type and attributes.
    op_name = node.op_type
    attr = self._parse_attr(node.attribute)
    # Create and populate the input list for this operator.
    inputs = onnx_input()
    # Collect all inputs (by name) of the current ONNX node.
    for i in node.input:
        if i != "":
            inputs.append(self._nodes[self._renames.get(i, i)])
        else:
            # Some optional inputs may be left empty (unused?).
            inputs.append(None)
    i_name = self._parse_value_proto(node)
    # Get the ONNX node's outputs: a list of output names (strings).
    node_output = self._fix_outputs(op_name, node.output)
    # Record the ONNX node's metadata in the attributes.
    attr["tvm_custom"] = {}
    attr["tvm_custom"]["name"] = i_name
    attr["tvm_custom"]["num_outputs"] = len(node_output)
    # Convert the ONNX operator node into its TVM representation.
    op = self._convert_operator(op_name, inputs, attr, opset)
Taking the mnist ONNX model shipped with the TVM source as the example again, we can print, at this point in the loop, the name and outputs of the ONNX node currently being processed along with the type and contents of the converted op; a sketch of such a print is shown below.
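A hypothetical sketch of such a debug print, placed right after the _convert_operator call above (the separator formatting is my own):

# `op` is the converted Relay expression; printing it emits Relay IR text.
print("#" * 48)
print("onnx op node", i_name, "output:", list(node_output), "convert to tvm op:", op)
print("#" * 53)

Running the import of mnist-8.onnx with this print in place produces the following trace: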
onnx op node Times212_reshape1 output: ['Parameter193_reshape1'] convert to tvm op:
free_var %Parameter193: Tensor[(16, 4, 4, 10), float32]; reshape(%Parameter193, newshape=[256, 10])

onnx op node Convolution28 output: ['Convolution28_Output_0'] convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32]; %0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter5: Tensor[(8, 1, 5, 5), float32]; nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5])

onnx op node Plus30 output: ['Plus30_Output_0'] convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32]; %0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter5: Tensor[(8, 1, 5, 5), float32]; %1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]); free_var %Parameter6: Tensor[(8, 1, 1), float32]; add(%1, %Parameter6)

onnx op node ReLU32 output: ['ReLU32_Output_0'] convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32]; %0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter5: Tensor[(8, 1, 5, 5), float32]; %1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]); free_var %Parameter6: Tensor[(8, 1, 1), float32]; %2 = add(%1, %Parameter6); nn.relu(%2)

onnx op node Pooling66 output: ['Pooling66_Output_0'] convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32]; %0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter5: Tensor[(8, 1, 5, 5), float32]; %1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]); free_var %Parameter6: Tensor[(8, 1, 1), float32]; %2 = add(%1, %Parameter6); %3 = nn.relu(%2); nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0])

onnx op node Convolution110 output: ['Convolution110_Output_0'] convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32]; %0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter5: Tensor[(8, 1, 5, 5), float32]; %1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]); free_var %Parameter6: Tensor[(8, 1, 1), float32]; %2 = add(%1, %Parameter6); %3 = nn.relu(%2); %4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]); %5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter87: Tensor[(16, 8, 5, 5), float32]; nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5])

onnx op node Plus112 output: ['Plus112_Output_0'] convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32]; %0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter5: Tensor[(8, 1, 5, 5), float32]; %1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]); free_var %Parameter6: Tensor[(8, 1, 1), float32]; %2 = add(%1, %Parameter6); %3 = nn.relu(%2); %4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]); %5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter87: Tensor[(16, 8, 5, 5), float32]; %6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]); free_var %Parameter88: Tensor[(16, 1, 1), float32]; add(%6, %Parameter88)

onnx op node ReLU114 output: ['ReLU114_Output_0'] convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32]; %0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter5: Tensor[(8, 1, 5, 5), float32]; %1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]); free_var %Parameter6: Tensor[(8, 1, 1), float32]; %2 = add(%1, %Parameter6); %3 = nn.relu(%2); %4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]); %5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter87: Tensor[(16, 8, 5, 5), float32]; %6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]); free_var %Parameter88: Tensor[(16, 1, 1), float32]; %7 = add(%6, %Parameter88); nn.relu(%7)

onnx op node Pooling160 output: ['Pooling160_Output_0'] convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32]; %0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter5: Tensor[(8, 1, 5, 5), float32]; %1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]); free_var %Parameter6: Tensor[(8, 1, 1), float32]; %2 = add(%1, %Parameter6); %3 = nn.relu(%2); %4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]); %5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter87: Tensor[(16, 8, 5, 5), float32]; %6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]); free_var %Parameter88: Tensor[(16, 1, 1), float32]; %7 = add(%6, %Parameter88); %8 = nn.relu(%7); nn.max_pool2d(%8, pool_size=[3, 3], strides=[3, 3], padding=[0, 0, 0, 0])

onnx op node Times212_reshape0 output: ['Pooling160_Output_0_reshape0'] convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32]; %0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter5: Tensor[(8, 1, 5, 5), float32]; %1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]); free_var %Parameter6: Tensor[(8, 1, 1), float32]; %2 = add(%1, %Parameter6); %3 = nn.relu(%2); %4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]); %5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter87: Tensor[(16, 8, 5, 5), float32]; %6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]); free_var %Parameter88: Tensor[(16, 1, 1), float32]; %7 = add(%6, %Parameter88); %8 = nn.relu(%7); %9 = nn.max_pool2d(%8, pool_size=[3, 3], strides=[3, 3], padding=[0, 0, 0, 0]); reshape(%9, newshape=[1, 256])

onnx op node Times212 output: ['Times212_Output_0'] convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32]; %0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter5: Tensor[(8, 1, 5, 5), float32]; %1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]); free_var %Parameter6: Tensor[(8, 1, 1), float32]; %2 = add(%1, %Parameter6); %3 = nn.relu(%2); %4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]); %5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter87: Tensor[(16, 8, 5, 5), float32]; %6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]); free_var %Parameter88: Tensor[(16, 1, 1), float32]; %7 = add(%6, %Parameter88); %8 = nn.relu(%7); %9 = nn.max_pool2d(%8, pool_size=[3, 3], strides=[3, 3], padding=[0, 0, 0, 0]); free_var %Parameter193: Tensor[(16, 4, 4, 10), float32]; %10 = reshape(%Parameter193, newshape=[256, 10]); %11 = reshape(%9, newshape=[1, 256]); %12 = transpose(%10, axes=[1, 0]); nn.dense(%11, %12, units=None, out_dtype="float32")

onnx op node Plus214 output: ['Plus214_Output_0'] convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32]; %0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter5: Tensor[(8, 1, 5, 5), float32]; %1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]); free_var %Parameter6: Tensor[(8, 1, 1), float32]; %2 = add(%1, %Parameter6); %3 = nn.relu(%2); %4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]); %5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); free_var %Parameter87: Tensor[(16, 8, 5, 5), float32]; %6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]); free_var %Parameter88: Tensor[(16, 1, 1), float32]; %7 = add(%6, %Parameter88); %8 = nn.relu(%7); %9 = nn.max_pool2d(%8, pool_size=[3, 3], strides=[3, 3], padding=[0, 0, 0, 0]); free_var %Parameter193: Tensor[(16, 4, 4, 10), float32]; %10 = reshape(%Parameter193, newshape=[256, 10]); %11 = reshape(%9, newshape=[1, 256]); %12 = transpose(%10, axes=[1, 0]); %13 = nn.dense(%11, %12, units=None, out_dtype="float32"); free_var %Parameter194: Tensor[(1, 10), float32]; add(%13, %Parameter194)
Notice that the returned op does not just contain the conversion of the current ONNX node: the expressions already built for its input nodes are folded in as well. Printing inputs shows that this accumulated prefix comes from the inputs argument. As a result, the expression attached to the model's final output node describes the computation of the whole network. The _convert_operator function that performs these conversions is as follows:
def _convert_operator(self, op_name, inputs, attrs, opset):
    """Convert ONNX operator into a Relay operator.

    The converter must specify conversions explicitly for incompatible name, and
    apply handlers to operator attributes.

    Parameters
    ----------
    op_name : str
        Operator name, such as Convolution, FullyConnected
    inputs : list of tvm.relay.function.Function
        List of inputs.
    attrs : dict
        Dict of operator attributes
    opset : int
        Opset version

    Returns
    -------
    sym : tvm.relay.function.Function
        Converted relay function
    """
    # Get the ONNX-operator-to-TVM conversion map.
    convert_map = _get_convert_map(opset)
    # The ONNX operator appears in the _identity_list table.
    if op_name in _identity_list:
        sym = get_relay_op(op_name)(*inputs, **attrs)
    # The operator appears in the conversion map.
    elif op_name in convert_map:
        # Perform the conversion.
        sym = convert_map[op_name](inputs, attrs, self._params)
    else:
        raise NotImplementedError("Operator {} not implemented.".format(op_name))
    # Return the conversion result.
    return sym
This function shows that ONNX-to-TVM operator mappings come from two sources: the table returned by _get_convert_map, and _identity_list. If an operator appears in neither table, it is not supported.
Searching the code shows that _identity_list is empty when the input is an ONNX, Caffe, or TensorFlow model; it is only non-empty for MXNet, where it mostly contains elementwise math functions (see mxnet.py):
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
    "abs",
    "log",
    "exp",
    "erf",
    "sqrt",
    "floor",
    "ceil",
    "round",
    "trunc",
    "sign",
    "sigmoid",
    "negative",
    "reshape_like",
    "zeros_like",
    "ones_like",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "where",
]
Presumably the operators in this table need no transformation between MXNet and TVM and can be used directly. For an ONNX model _identity_list is empty, so every operator mapping comes from the table returned by _get_convert_map:
# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
def _get_convert_map(opset):
    return {
        # defs/experimental
        "Identity": Renamer("copy"),
        "Affine": Affine.get_converter(opset),
        "BitShift": BitShift.get_converter(opset),
        "ThresholdedRelu": ThresholdedRelu.get_converter(opset),
        "ScaledTanh": ScaledTanh.get_converter(opset),
        "ParametricSoftplus": ParametricSoftPlus.get_converter(opset),
        "Constant": Constant.get_converter(opset),
        "ConstantOfShape": ConstantOfShape.get_converter(opset),
        # 'GivenTensorFill'
        "FC": AttrCvt("dense", ignores=["axis", "axis_w"]),
        "Scale": Scale.get_converter(opset),
        ...
    }
Judging from the comments, the mapping covers three scenarios (a sketch of the 1-to-1 case follows this list):
1. The ONNX operator and the TVM operator differ only in name. This is a 1-to-1 mapping, handled by calling Renamer with the TVM operator name; if the attributes also need converting, AttrCvt is used instead;
2. One ONNX operator is composed of several TVM operators (1-to-N). This case goes through the operator's get_converter function, a custom callable;
3. Several ONNX operators combining into one TVM operator (N-to-1) is currently not supported.
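To make the 1-to-1 case concrete, here is a simplified sketch of what Renamer does; the real class lives in python/tvm/relay/frontend/common.py, and this version keeps only the essentials:

class Renamer(object):
    """Sketch of Renamer: call the Relay op under a new name,
    forwarding the inputs and attributes unchanged."""

    def __init__(self, new_name):
        self._new_name = new_name

    def __call__(self, inputs, attrs, *args):
        # Drop the importer-internal bookkeeping stored under "tvm_custom"
        # before forwarding the attributes to the Relay operator.
        if "tvm_custom" in attrs:
            attrs.pop("tvm_custom")
        return get_relay_op(self._new_name)(*inputs, **attrs)

So "Identity": Renamer("copy") simply re-emits the node as Relay's copy operator with its inputs and attributes untouched.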
Take the convolution operator as an example: during conversion, _convert_operator invokes Conv.get_converter(opset)(inputs, attrs, self._params). The Conv class inherits from OnnxOpConverter, and get_converter is a method of OnnxOpConverter, defined as follows:
class OnnxOpConverter(object):
    """A helper class for holding onnx op converters."""

    @classmethod
    def get_converter(cls, opset):
        """Get converter matches given opset.

        Parameters
        ----------
        opset: int
            opset from model.

        Returns
        -------
        converter, which should be `_impl_vx`. Number x is the biggest
            number smaller than or equal to opset belongs to all support versions.
        """
        # dir(cls) lists the class's attributes: special members, ordinary
        # members, and methods. Pick out the attribute names containing the
        # string "_impl_v", strip that prefix, and convert the rest to int.
        # Each _impl_vx attribute is a conversion method for the operator,
        # with the integer x being an opset version, so `versions` ends up
        # holding all opset versions the operator supports.
        versions = [int(d.replace("_impl_v", "")) for d in dir(cls) if "_impl_v" in d]
        # Append the requested opset version and sort ascending.
        versions = sorted(versions + [opset])
        # The max(...) expression finds the index of the (last occurrence of
        # the) requested opset in the sorted list; stepping back one element
        # yields the largest supported version that is <= opset.
        version = versions[max([i for i, v in enumerate(versions) if v == opset]) - 1]
        # Join that version number with "_impl_v" to form a method name. If
        # the operator class has that method, return it; otherwise report
        # the opset version as unsupported.
        if hasattr(cls, "_impl_v{}".format(version)):
            return getattr(cls, "_impl_v{}".format(version))
        raise NotImplementedError(
            "opset version {} of {} not implemented".format(version, cls.__name__)
        )
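A toy example of the version-selection logic, using the hypothetical converter class FakeConv (the class and opset values below are made up for illustration):

class FakeConv(OnnxOpConverter):
    """Hypothetical converter supporting opset versions 1 and 11."""

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return "converted with _impl_v1"

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        return "converted with _impl_v11"

# With opset=10: versions = sorted([1, 11] + [10]) = [1, 10, 11]; the element
# just before the requested 10 is 1, so _impl_v1 is returned.
print(FakeConv.get_converter(10)(None, None, None))  # converted with _impl_v1
# With opset=12: versions = [1, 11, 12], so _impl_v11 is returned.
print(FakeConv.get_converter(12)(None, None, None))  # converted with _impl_v11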
Now look at the _impl_vx method the Conv class provides:
class Conv(OnnxOpConverter):
    """Operator converter for Conv."""

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Use shape of input to determine convolution type.
        # Pull the input data and the convolution kernel out of the
        # `inputs` argument, and infer their shapes.
        data = inputs[0]
        kernel = inputs[1]
        input_shape = infer_shape(data)
        ndim = len(input_shape)
        kernel_type = infer_type(inputs[1])
        kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)]
        # If the ONNX conv attributes do not give the kernel shape,
        # use the shape inferred from the inputs.
        if "kernel_shape" not in attr:
            attr["kernel_shape"] = kernel_shapes[0][2:]
        # The ONNX conv operator set the auto_pad attribute.
        if "auto_pad" in attr:
            # Carry the same auto_pad value over to the TVM convolution.
            attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
            # Pad the data according to the auto_pad value.
            if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
                # Warning: Convolution does not yet support dynamic shapes,
                # one will need to run dynamic_to_static on this model after import
                # Pad the input data, producing the padded tensor.
                data = autopad(
                    data,
                    attr.get("strides", [1] * (ndim - 2)),
                    attr["kernel_shape"],
                    attr.get("dilations", [1] * (ndim - 2)),
                    mode=attr["auto_pad"],
                )
            elif attr["auto_pad"] == "VALID":
                attr["pads"] = [0 for i in range(ndim - 2)]
            elif attr["auto_pad"] == "NOTSET":
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))
            attr.pop("auto_pad")
        attr["channels"] = kernel_shapes[0][0]
        out = AttrCvt(
            # The op_name passed here is a function; AttrCvt.__call__ invokes
            # it with the current attr, and the kernel_shape attribute decides
            # which TVM operator interface (conv1d/conv2d/conv3d) to use. The
            # AttrCvt instance is then called with ([data, kernel], attr,
            # params) and returns the converted TVM expression `out`.
            op_name=dimension_picker("conv"),
            transforms={
                "kernel_shape": "kernel_size",
                "dilations": ("dilation", 1),
                "pads": ("padding", 0),
                "group": ("groups", 1),
            },
            custom_check=dimension_constraint(),
        )([data, kernel], attr, params)

        use_bias = len(inputs) == 3
        # If a bias input is present, append a bias-add to the expression.
        if use_bias:
            out = _op.nn.bias_add(out, inputs[2])
        return out
In _impl_v1, the convolution's input data, kernel parameters, and padding receive some preliminary handling, and then an AttrCvt instance is created. The op_name argument passed in is a function; AttrCvt.__call__ invokes it with the convolution's current attr, and based on the kernel_shape attribute it decides whether this is a 1d/2d/3d convolution and yields the corresponding TVM operator name conv1d/conv2d/conv3d. The transforms argument is used inside AttrCvt.__call__ to convert the current attributes and weight parameters into the form TVM's convolution expects. The custom_check argument validates the attributes; for convolution, that means checking that the dimensionality is legal (1d/2d/3d).
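The dimension_picker helper deserves a quick look. I have not reproduced the frontend's exact code here; the sketch below captures the behavior described above (deriving the nD operator name from the kernel rank):

def dimension_picker(prefix, suffix=""):
    """Return a callable that derives the operator name ("conv1d"/"conv2d"/
    "conv3d" for prefix="conv") from the rank of attr["kernel_shape"]."""

    def _impl(attr):
        kernel = attr["kernel_shape"]
        if len(kernel) == 1:
            return prefix + "1d" + suffix
        if len(kernel) == 2:
            return prefix + "2d" + suffix
        if len(kernel) == 3:
            return prefix + "3d" + suffix
        raise NotImplementedError("only 1D, 2D and 3D kernels are supported")

    return _impl

# For the mnist model's 5x5 kernels: dimension_picker("conv")(attr) -> "conv2d".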
AttrCvt.__call__ roughly proceeds as follows: it checks and converts the attributes, calls get_relay_op to obtain the TVM interface function for the operator, and feeds the operator's inputs plus the transformed attributes into that interface, yielding the TVM Relay IR corresponding to the ONNX node.
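To illustrate the transforms argument, here is a deliberately simplified, hypothetical version of the renaming step (the real AttrCvt.__call__ additionally handles ignores/excludes/disables lists, callable transforms, and error reporting):

def apply_transforms(attr, transforms):
    """Hypothetical reduction of AttrCvt's attribute renaming."""
    new_attr = {}
    for old_name, rule in transforms.items():
        if isinstance(rule, str):
            # Plain string: rename only, keep the attribute if present.
            if old_name in attr:
                new_attr[rule] = attr[old_name]
        else:
            # (new_name, default) pair: rename, falling back to the default.
            new_name, default = rule
            new_attr[new_name] = attr.get(old_name, default)
    return new_attr

attr = {"kernel_shape": [5, 5], "group": 1}
transforms = {"kernel_shape": "kernel_size", "pads": ("padding", 0), "group": ("groups", 1)}
print(apply_transforms(attr, transforms))
# {'kernel_size': [5, 5], 'padding': 0, 'groups': 1}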
AttrCvt is a shared utility, not specific to ONNX models; the detailed flow of AttrCvt.__call__ is fairly involved, and I have not fully worked it out yet. The get_relay_op code:
def get_relay_op(op_name):
    """Get the callable function from Relay based on operator name.

    Parameters
    ----------
    op_name : str
        The Relay operator name.
    """
    if "." in op_name:
        # explicit hierarchical modules
        op = _op
        try:
            for opn in op_name.split("."):
                op = getattr(op, opn)
        except AttributeError:
            op = None
    else:
        # try search op in various modules
        for candidate in (_op, _op.nn, _op.image, _op.vision, _op.contrib):
            op = getattr(candidate, op_name, None)
            if op is not None:
                break
    if not op:
        raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name))
    return op
This function handles two forms of op_name. A dot-separated name such as nn.conv2d is resolved by walking down the module hierarchy attribute by attribute, starting from _op. The else branch handles a bare operator name such as conv2d, searching for it in several Relay operator modules (_op, _op.nn, _op.image, _op.vision, _op.contrib) in turn. python/tvm/relay/op/ is the directory holding the TVM Relay operator implementations, with subdirectories such as nn, image, vision, and contrib for the different operator families. When parsing the mnist-8.onnx model, the op_name passed in is a bare interface name with no separator, so it is resolved through the else-branch search.
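A quick interactive check of both branches (a sketch, assuming a TVM installation; get_relay_op is defined in tvm.relay.frontend.common):

from tvm.relay.frontend.common import get_relay_op

# Dot-separated name: walk the module hierarchy starting from _op.
print(get_relay_op("nn.conv2d"))
# Bare name: found by probing _op, _op.nn, _op.image, _op.vision, _op.contrib.
print(get_relay_op("conv2d"))
# Both resolve to the same callable, tvm.relay.op.nn.conv2d.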
From the printed TVM Relay IR we can see that the conv2d here is nn.conv2d, whose code is:
def conv2d(
    data,
    weight,
    strides=(1, 1),
    padding=(0, 0),
    dilation=(1, 1),
    groups=1,
    channels=None,
    kernel_size=None,
    data_layout="NCHW",
    kernel_layout="OIHW",
    out_layout="",
    out_dtype="",
):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(strides, int):
        strides = (strides, strides)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    # TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged
    # convert 2-way padding to 4-way padding
    padding = get_pad_tuple2d(padding)
    return _make.conv2d(
        data,
        weight,
        strides,
        padding,
        dilation,
        groups,
        channels,
        kernel_size,
        data_layout,
        kernel_layout,
        out_layout,
        out_dtype,
    )
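Calling this interface directly reproduces the kind of IR fragment seen in the trace above (a minimal sketch, assuming a TVM installation):

from tvm import relay

data = relay.var("Input3", shape=(1, 1, 28, 28), dtype="float32")
weight = relay.var("Parameter5", shape=(8, 1, 5, 5), dtype="float32")
out = relay.nn.conv2d(data, weight, channels=8, kernel_size=(5, 5))
print(out)
# free_var %Input3: Tensor[(1, 1, 28, 28), float32];
# free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
# nn.conv2d(%Input3, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5])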
Here _make is the _make.py module in the same directory, imported in python/tvm/relay/op/nn/__init__.py:
import tvm._ffi
tvm._ffi._init_api("relay.op.nn._make", __name__)
_init_api is defined in python/tvm/_ffi/registry.py:
def _init_api(namespace, target_module_name=None):
    """Initialize api for a given module name

    namespace : str
        The namespace of the source registry

    target_module_name : str
        The target module name if different from namespace
    """
    target_module_name = target_module_name if target_module_name else namespace
    if namespace.startswith("tvm."):
        _init_api_prefix(target_module_name, namespace[4:])
    else:
        _init_api_prefix(target_module_name, namespace)
The namespace argument passed in here is relay.op.nn._make, and target_module_name is the __name__ attribute of _make.py, i.e. its module path tvm.relay.op.nn._make. Since relay.op.nn._make does not start with "tvm.", _init_api_prefix ends up being called with tvm.relay.op.nn._make and relay.op.nn._make.
def _init_api_prefix(module_name, prefix):
    module = sys.modules[module_name]
    for name in list_global_func_names():
        if not name.startswith(prefix):
            continue
        fname = name[len(prefix) + 1 :]
        target_module = module

        if fname.find(".") != -1:
            continue
        f = get_global_func(name)
        ff = _get_api(f)
        ff.__name__ = fname
        ff.__doc__ = "TVM PackedFunc %s. " % fname
        setattr(target_module, ff.__name__, ff)
module = sys.modules[module_name] obtains a handle to the tvm.relay.op.nn._make module, and list_global_func_names() returns the names of all currently registered global functions; these global functions are implemented in C++ and registered so they can be called from the Python side. Every function whose name starts with prefix (relay.op.nn._make) is wrapped and installed, via setattr, as an attribute of the tvm.relay.op.nn._make module. So the _make.conv2d called by conv2d in nn.py above is really the Python-side binding of a C++ interface set up here.
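This is easy to verify interactively (a sketch, assuming a TVM installation):

import tvm
from tvm._ffi.registry import list_global_func_names

# C++ global functions registered under the relay.op.nn._make prefix.
names = [n for n in list_global_func_names() if n.startswith("relay.op.nn._make")]
print("relay.op.nn._make.conv2d" in names)          # True
# The PackedFunc that _make.conv2d is bound to.
print(tvm.get_global_func("relay.op.nn._make.conv2d"))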
Convolution is only the example analyzed here. Every operator directory under tvm/relay/op contains a _make.py that installs the corresponding C++ interface bindings as attributes of its module, so a call to any Python-side operator interface ultimately lands in the matching C++ implementation.



