栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > C/C++/C#

win10 c++调用pytorch模型

C/C++/C# 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

win10 c++调用pytorch模型

1.pytorch模型生成pt模型

"""Export a pth model to Torchscript formats


import time
import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile
from model.model import parsingNet



def main():
    net=“测试代码中调用模型的代码”
    state_dict = torch.load("./model/ep099.pth", map_location='cpu')['model']
    net.load_state_dict(compatible_state_dict, strict=False)
    net.eval()

    # An example input you would normally provide to your model's forward() method.
    example = torch.rand(1,3,288,800).cuda()

    # Use torch.jit.trace to generate a torch.jit.scriptModule via tracing.
    traced_script_module = torch.jit.trace(net, example)
    output = traced_script_module(torch.ones(1,3,288,800).cuda())
    traced_script_module.save("./model/best.pt")

    # The traced scriptModule can now be evaluated identically to a regular PyTorch module
    print(output)


if __name__ == "__main__":
    main()

2. vs2019下配置libtorch

注意libtorch版本和训练模型的pytorch版本一致

3. 使用c++调用pytorch模型

#include 
#include 

int main(void)
{
	
	torch::jit::script::Module module = torch::jit::load("best.pt");

	assert(module != nullptr);

	std::cout << "Model is loaded!" << std::endl;
	// Create a vector of inputs.
	std::vector inputs;
	inputs.push_back(torch::ones({ 1, 3, 288, 800 }).cuda());

	// Execute the model and turn its output into a tensor.
	at::Tensor result = module.forward(inputs).toTensor();

	std::cout << result << std::endl;

	system("pause");

	return 0;
}

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/702787.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号