栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

TensorRTx工程源码解读——YoloV5(一)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

TensorRTx工程源码解读——YoloV5(一)

# Load model model torch.load(pt_file, map_location device)[ model ].float() # load to FP32 model.to(device).eval() with open(wts_file, w ) as f: f.write( {}n .format(len(model.state_dict().keys()))) for k, v in model.state_dict().items(): vr v.reshape(-1).cpu().numpy() f.write( {} {} .format(k, len(vr))) for vv in vr: f.write( ) f.write(struct.pack( f ,float(vv)).hex()) f.write( n )

第一个函数parse_args()就是正常处理输入的命令行参数 不多做赘述。

主函数内 先是设置设备为CPU 再load进pt文件获得model并转成FP32格式。并设置模型的device和eval模式。

设置完毕后 作者保存权重文件 其中权重文件的内容是作者自定义的。第一行存入的是model的keys的个数 再分别遍历pt文件内的每一个权重 保存为该层名称 该层参数量 16进制权重。

权重读取 common.cpp

首先顺着之前的思路 看看作者是如何load权重的。

// TensorRT weight files have a simple space delimited format:
// [type] [size] data x size in hex 
std::map std::string, Weights loadWeights(const std::string file) {
 std::cout Loading weights: file std::endl;
 std::map std::string, Weights weightMap;
 // Open weights file
 std::ifstream input(file);
 assert(input.is_open() Unable to load weight file. please check if the .wts file path is right!!!!!! );
 // Read number of weight blobs
 int32_t count;
 input count;
 assert(count 0 Invalid weight map file. );
 while (count--)
 Weights wt{ DataType::kFLOAT, nullptr, 0 };
 uint32_t size;
 // Read name and type of blob
 std::string name;
 input name std::dec size;
 wt.type DataType::kFLOAT;
 // Load blob
 uint32_t* val reinterpret_cast uint32_t* (malloc(sizeof(val) * size));
 for (uint32_t x 0, y size; x y; x)
 input std::hex val[x];
 wt.values val;
 wt.count size;
 weightMap[name] wt;
 return weightMap;

此为loadWeight()函数。作者此处使用了std::map容器。map容器在OpenCV和OpenVINO中本身就是大量使用的 所以除了vector之外 也需要掌握map的使用。后面需要往这个 std::string, Weights 型的map中添加权重信息。

同时应该注意 此处的Weights类型在TensorRT的NvInferRuntime.h头文件中有定义:

class Weights
public:
 DataType type; //! The type of the weights.
 const void* values; //! The weight values, in a contiguous array.
 int64_t count; //! The number of weights in the array.

作者使用了std::ifstream进行输入流变量的定义 并设置了一些变量。代码中的input count就是将.wts文件中的第一行的算子数传递给count这个变量 从而构建while循环。

在While循环中 作者先定义了Weights型的wt变量 其类型为DataType::kFLOAT values直接初始化为nullptr count初始化一个0在上面即可。

这一句input name std::dec size是将input中的第一部分 权重的名称 赋值给name变量 再将紧跟着name后的size推入给size变量。具体的形式可以参考之前分析gen_wts.py脚本中的权重生成的部分。作者之所以要存入这一算子的权重的size 就是为了方便分配空间大小。声明指针val指向一个大小为sizeof(val) * size的uint32_t的数组 并且将input中这一行的权重全部推入给val这个数组即可。

这一步完成后 设置Weights的values成员为val count成员为size 并将name作为weightMap的keys wt作为其values即可。

至此 模型权重加载完毕。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/266805.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号