Converting a PyTorch model to ONNX / converting a PyTorch model to Paddle?

Converting a PyTorch model to TFLite. Reference documents:

1. https://blog.csdn.net/computerme/article/details/84144930

2. https://blog.csdn.net/qq_40600539/article/details/123142541

Environment:
# tensorflow        2.4.0
# onnx              1.8.0
# onnx-tensorflow   1.8.0 [onnx-tf]
# tf-nightly        2.9.0
# pytorch           1.8.0
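The conversion chain is sensitive to package versions, so it may be worth confirming what is installed before running the script. The following is a minimal sketch using only the standard library. The distribution names are assumed to be the usual PyPI names (for example `torch` for pytorch and `onnx-tf` for onnx-tensorflow), and the helper names `installed_version` and `mismatches` are illustrative, not part of any of these libraries.

```python
from importlib import metadata

# Versions copied from the list above. The distribution names are assumed to be
# the usual PyPI names ("torch" for pytorch, "onnx-tf" for onnx-tensorflow).
EXPECTED = {
    "tensorflow": "2.4.0",
    "onnx": "1.8.0",
    "onnx-tf": "1.8.0",
    "tf-nightly": "2.9.0",
    "torch": "1.8.0",
}


def installed_version(name):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def mismatches(expected):
    """List (name, wanted, found) for packages that are missing or differ."""
    problems = []
    for name, wanted in expected.items():
        found = installed_version(name)
        # Prefix match so local build tags such as "1.8.0+cpu" still pass
        if found is None or not found.startswith(wanted):
            problems.append((name, wanted, found))
    return problems


if __name__ == "__main__":
    for name, wanted, found in mismatches(EXPECTED):
        print(f"{name}: expected {wanted}, found {found}")
```

An empty report means every listed package is present with a matching version prefix.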
Reference code
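import os

# Hide all GPUs so the whole conversion runs on the CPU.
# Set this before the frameworks are imported so that it takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"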

import onnx
from onnx_tf.backend import prepare
import tensorflow as tf
from onnxsim import simplify
import onnxruntime as ort
import numpy as np
import torch.nn as nn
import torch

class Model(nn.Module):
    """Small two-block CNN (Conv-BN-ReLU-MaxPool, twice) used to demonstrate the conversion."""

    def __init__(self):
        super().__init__()
        conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 1, groups=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        self.feature = nn.Sequential(conv1, conv2)
        self.init_weights()

    def forward(self, x):
        return self.feature(x)

    def init_weights(self):
        """Kaiming-initialize conv weights; reset BatchNorm to scale 1, shift 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight.data, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

if __name__ == '__main__':
    model = Model()
    # Convert the model to ONNX.
    # Switch to inference mode first so BatchNorm uses its running statistics.
    model.eval()

    test_arr = np.random.randn(1, 3, 480, 640).astype(np.float32)
    sample_input = torch.tensor(test_arr)
    # sample_input = torch.randn(1, 3, 480, 640)
    input_nodes = ['input']
    output_nodes = ['output']

    # Sanity-check the forward pass before exporting
    model(sample_input)

    torch.onnx.export(model, sample_input, "model.onnx", export_params=True, input_names=input_nodes,
                      output_names=output_nodes, opset_version=11)

    # Run the exported file once with ONNX Runtime to confirm it loads and executes
    ort_session = ort.InferenceSession('model.onnx')
    onnx_outputs = ort_session.run(None, {'input': test_arr})
    print('Export ONNX!')

    onnx_model = onnx.load("model.onnx")
    model_simp, check = simplify(onnx_model)
    assert check, "Simplified ONNX model could not be validated"

    # Convert the simplified ONNX graph to a TensorFlow SavedModel
    tf_rep = prepare(model_simp)
    tf_rep.export_graph("tf_model/")
    print('Export tf_model!')

    converter = tf.lite.TFLiteConverter.from_saved_model("tf_model")
    tflite_model = converter.convert()
    # Use a context manager so the file is flushed and closed
    with open("model.tflite", "wb") as f:
        f.write(tflite_model)
    print('Export tf lite model!')
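The script above never inspects the outputs it produces, so it can help to know what output shape to expect from the exported models. The sketch below traces the spatial size through the four layers of `Model.feature` with the standard output-size formula `floor((n - kernel) / stride) + 1`. It assumes the PyTorch defaults that the reference model relies on (no padding, no dilation, floor rounding in `MaxPool2d`). The names `out_size`, `trace_shape`, and `LAYERS` are illustrative only.

```python
def out_size(n, kernel, stride):
    """Output length of an unpadded, undilated conv/pool window (floor mode)."""
    return (n - kernel) // stride + 1


# (kernel, stride) of each spatial layer in Model.feature, in order:
# Conv2d(3, 32, 3, 2), MaxPool2d(2, 2), Conv2d(32, 64, 3, 1), MaxPool2d(2, 2)
LAYERS = [(3, 2), (2, 2), (3, 1), (2, 2)]


def trace_shape(height, width, layers=LAYERS):
    """Apply every layer in turn and return the final (height, width)."""
    for kernel, stride in layers:
        height = out_size(height, kernel, stride)
        width = out_size(width, kernel, stride)
    return height, width


if __name__ == "__main__":
    h, w = trace_shape(480, 640)
    # 64 is the channel count of the last Conv2d in the model
    print((1, 64, h, w))  # prints (1, 64, 58, 78)
```

For the 480 x 640 sample input, `onnx_outputs[0].shape` should therefore be `(1, 64, 58, 78)` if the export preserved the network.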
If you repost this article, please credit the source: www.mshxw.com
Article URL: https://www.mshxw.com/it/786871.html