栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > C/C++/C#

onnxruntime

C/C++/C# 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

onnxruntime

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <iterator>
#include <memory>
#include <vector>

#include <onnxruntime_cxx_api.h>

#include "cuda_provider_factory.h"
int main(int argc, char* argv[])
{

    bool useCUDA{ true };

    const wchar_t* model_path = L"squeezenet1.1-7.onnx";
    size_t inputTensorSize = 3 * 224 * 224;
    size_t outputTensorSize = 1000;


    Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "cuda");
    Ort::SessionOptions sessionOptions;
    sessionOptions.SetIntraOpNumThreads(1);
    if (useCUDA)
    {
        // Using CUDA backend
        // https://github.com/microsoft/onnxruntime/blob/v1.8.2/include/onnxruntime/core/session/onnxruntime_cxx_api.h#L329
        OrtCUDAProviderOptions cuda_options{ 0 };
        sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
    }
    sessionOptions.SetGraphOptimizationLevel(
        GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

    Ort::Session session(env, model_path, sessionOptions);

    Ort::AllocatorWithDefaultOptions allocator;

    size_t numInputNodes = session.GetInputCount();
    size_t numOutputNodes = session.GetOutputCount();

    std::cout << "Number of Input Nodes: " << numInputNodes << std::endl;
    std::cout << "Number of Output Nodes: " << numOutputNodes << std::endl;

    const char* inputName = session.GetInputName(0, allocator);
    std::cout << "Input Name: " << inputName << std::endl;

    Ort::TypeInfo inputTypeInfo = session.GetInputTypeInfo(0);
    auto inputTensorInfo = inputTypeInfo.GetTensorTypeAndShapeInfo();

    onNXTensorElementDataType inputType = inputTensorInfo.GetElementType();
    std::cout << "Input Type: " << inputType << std::endl;

    std::vector inputDims = inputTensorInfo.GetShape();
    //std::cout << "Input Dimensions: " << inputDims << std::endl;

    const char* outputName = session.GetOutputName(0, allocator);
    std::cout << "Output Name: " << outputName << std::endl;

    Ort::TypeInfo outputTypeInfo = session.GetOutputTypeInfo(0);
    auto outputTensorInfo = outputTypeInfo.GetTensorTypeAndShapeInfo();

    onNXTensorElementDataType outputType = outputTensorInfo.GetElementType();
    std::cout << "Output Type: " << outputType << std::endl;

    std::vector outputDims = outputTensorInfo.GetShape();
    //std::cout << "Output Dimensions: " << outputDims << std::endl;   


    std::vector inputTensorValues(inputTensorSize);
    for (unsigned int i = 0; i < inputTensorSize; i++)
        inputTensorValues[i] = ( float )i / (inputTensorSize + 1);


    std::vector outputTensorValues(outputTensorSize);

    std::vector inputNames{ inputName };
    std::vector outputNames{ outputName };
    std::vector inputTensors;
    std::vector outputTensors;

    Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(
        OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
    inputTensors.push_back(Ort::Value::CreateTensor(
        memoryInfo, inputTensorValues.data(), inputTensorSize, inputDims.data(),
        inputDims.size()));
    outputTensors.push_back(Ort::Value::CreateTensor(
        memoryInfo, outputTensorValues.data(), outputTensorSize,
        outputDims.data(), outputDims.size()));

    session.Run(Ort::RunOptions{ nullptr }, inputNames.data(),
        inputTensors.data(), 1, outputNames.data(),
        outputTensors.data(), 1);

    printf("Done!n");
    system("pause");
    return 0;
}

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/630605.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号