栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > C/C++/C#

c++测试pytorch训练的模型

C/C++/C# 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

c++测试pytorch训练的模型

c++测试pytorch训练的模型

pytorch训练的pth模型转换成onnx模型,使用c++测试onnx模型。

1. 模型.pth转.onnx

化繁为简

写那么多废话不如简单明了

import torch
from Unet import Unet


def pth2onnx(input, pth_path, onnx_path):
    model = Unet()  # 导入自己的网络模型
    model.load_state_dict(torch.load(pth_path))  # 初始化权重
    model.eval()

    torch.onnx.export(model, input, onnx_path, verbose=True)


if __name__ == '__main__':
    pth_path = r'./best_model.pth'  # 训练的pth路径
    onnx_path = r'./best_model.onnx'  # 保存onnx的路径
    model_input = torch.randn(1, 1, 512, 512)  # 模型输入[B,C,H,W]
    pth2onnx(input=model_input, pth_path=pth_path, onnx_path=onnx_path)
(可选)2. 测试.onnx模型转换是否正确

如果第3步模型测试不正确,才需要用第2步来检查是不是模型转换出了问题。

import cv2
import onnxruntime
import numpy as np

onnx_path = './best_model.onnx'  # 上一步生成的onnx模型
image_path = './data/test/1.bmp'  # 测试图像

image = cv2.imread(image_path)  # 读取图像
image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_LINEAR)  # resize成相应尺寸
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # 灰度处理,使用的是单通道图像预测

# 处理成模型需要的格式,[B,C,H,W]
input = image.reshape(1, 1, image.shape[0], image.shape[1]).astype(np.float32)

session = onnxruntime.InferenceSession(onnx_path)
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: input})
print(outputs[0].shape)  # (1, 1, 512, 512)

pred = np.array(outputs[0])[0][0]
pred[pred > 0] = 255
pred[pred <= 0] = 0
cv2.imwrite("pred.bmp", pred)
3. C++测试

使用的是Qt进行的测试。

#include 
#include 
#include 
using namespace cv;
using namespace cv::dnn;
using namespace std;


int main()
{
    int h = 512; int w = 512;
    String modelFile = "F:/project/Unet_model/best_model.onnx";
    String imageFile = "F:/project/Unet_model/data/test/1.bmp";

    Mat img = imread(imageFile); // 读取测试图片
    cvtColor(img, img, cv::COLOR_BGR2GRAY);  // 灰度化
    resize(img, img, Size(h, w));

    Mat inputBolb = blobFromImage(img);  // 转换输入图像的格式[B,C,H,W]

    dnn::Net net = cv::dnn::readNetFromONNX(modelFile); //读取网络和参数
    net.setInput(inputBolb);
    Mat output = net.forward();  // 输出4D mat

    int B = inputBolb.size[0];
    int C = inputBolb.size[1];
    int H = inputBolb.size[2];
    int W = inputBolb.size[3];

    Mat predMat = Mat::zeros(h, w, CV_32F);

    for(int i = 0; i < B; i++){
        for(int j = 0; j < C; j++){
            for(int m = 0; m < H; m++){
                for(int n =0; n < W; n++){

                    float pred = output.ptr(i,j,m)[n];

                    if(pred > 0){
                        predMat.at(m,n) = 255;
                    }
                    else{
                        predMat.at(m,n) = 0;
                    }

                }
            }
        }
    }

    cv::imwrite("F:/QtProject/pred.bmp", predMat);
}
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/665855.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号