栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > C/C++/C#

【TVM源码学习笔记】5. C++侧的relay ir op

C/C++/C# 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

【TVM源码学习笔记】5. C++侧的relay ir op

在【TVM源码学习笔记】4. tvm relay python调用C++的流程 中分析到onnx卷积算子在转为tvm relay ir卷积算子时,在C++侧的MakeConv接口中生成了一个Op类实例:

const Op& op = Op::Get(op_name);

这个Op类的定义include/tvm/ir/op.h。Op::Get的实现:

const Op& Op::Get(const String& name) {
  const OpRegEntry* reg = OpRegistry::Global()->Get(name);
  ICHECK(reg != nullptr) << "AttributeError: Operator " << name << " is not registered";
  return reg->op();
}

using OpRegistry = AttrRegistry;



template 
class AttrRegistry {
 public:
  using TSelf = AttrRegistry;
  
  const EntryType* Get(const String& name) const {
    auto it = entry_map_.find(name);
    if (it != entry_map_.end()) return it->second;
    return nullptr;
  }

  
  EntryType& RegisterOrGet(const String& name) {
    auto it = entry_map_.find(name);
    if (it != entry_map_.end()) return *it->second;
    uint32_t registry_index = static_cast(entries_.size());
    auto entry = std::unique_ptr(new EntryType(registry_index));
    auto* eptr = entry.get();
    eptr->name = name;
    entry_map_[name] = eptr;
    entries_.emplace_back(std::move(entry));
    return *eptr;
  }

 ...

private:
 ...
 // entries in the registry
 std::vector> entries_;
 // map from name to entries.
 std::unordered_map entry_map_;
 ...
}

Op::Get调用OpRegistry::Global()->Get(name)返回一个实例。OpRegistry::Global()->Get则是从entry_map_中按照字符串name查找元素。而这个entry_map_表中的项是通过RegisterOrGet写入的。在RegisterOrGet中,先是在entry_map_表中查找key为name表项是否存在;如果存在,直接返回该表象的value;如果不存在,就new一个EntryType类型实例,然后用get获取类型指针(由entry_map_的定义倒推,eptr为EntryType*类型),将这个数据实例加入到entry_map_表和entries_表中。

搜索这个方法,可以看到在RELAY_REGISTER_OP宏的定义中有使用:

#define RELAY_REGISTER_OP(OpName) TVM_REGISTER_OP(OpName)


#define TVM_STR_CONCAt_(__x, __y) __x##__y
#define TVM_STR_CONCAt(__x, __y) TVM_STR_CONCAt_(__x, __y)

#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op


#define TVM_REGISTER_OP(OpName)                          
  TVM_STR_CONCAt(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = 
      ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()

RELAY_REGISTER_OP用于注册一个算子。TVM_OBJECT_REG_VAR_DEF定义了一个静态变量,包括变量类型 (::tvm::OpRegEntry&)和变量名的前半部分;TVM_STR_CONCAT将__COUNTER__和前半部分拼接在一起成为完整的变量名。__COUNTER__宏是一个计数器,保证在编译过程中产生一个独一无二的数字,这样这个拼接后的变量名也将是独一无二的。

RELAY_REGISTER_OP宏的详细分析可以参考
深入理解TVM:RELAY_REGISTER_OP

算子实现和注册可以参考tvm手册

Adding an Operator to Relay 

例如conv2d的注册:

RELAY_REGISTER_OP("nn.conv2d")
    .describe(R"code(2D convolution layer (e.g. spatial convolution over images).

This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.

- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
- **out**:  This depends on the `layout` parameter. Output is 4D array of shape
            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("Conv2D", Conv2DRel)
    .set_attr("FInferCorrectLayout", ConvInferCorrectLayout);

也就是说当tvm中实现一个算子时,会调用 RELAY_REGISTER_OP进行注册,该注册会在 AttrRegistry(这是个单例模式的类)的entry_map_中加入一个OpRegEntry实例。而tvm处理一个外部输入的模型时,如果遇到这个算子,就从entry_map_表中读取对应的OpRegEntry实例,然后调用OpRegEntry::op(见Op::Get方法),获取对应的Op实例:

class OpRegEntry {
 public:
  
  const Op& op() const { return op_; }
  ...
 private:
  ...
  
  Op op_;
  ...
}

OpRegEntry::OpRegEntry(uint32_t reg_index) {
  ObjectPtr n = make_object();
  n->index_ = reg_index;
  op_ = Op(n);
}

Op类继承自ObjectRef,对应的数据类型为OpNode。在该类中记录了算子的名称,类型,属性,输入等信息。还提供了属性的访问入口VisitAttrs

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/766925.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号