pytorch->onnx->tf->tflite

1. pytorch->onnx

    try:
        import onnx
        import torch

        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        f = opt.weights.replace('.pth', '.onnx').replace('.pt', '.onnx')  # output filename
        # model, img and y are assumed to be defined by the surrounding export script
        torch.onnx.export(model, img, f, verbose=False, opset_version=10, input_names=['images'],
                          output_names=['classes', 'boxes'] if y is None else ['output'])

        # Checks
        onnx_model = onnx.load(f)  # load onnx model
        onnx.checker.check_model(onnx_model)  # check onnx model
        # print(onnx.helper.printable_graph(onnx_model.graph))  # print a human-readable model
        print('ONNX export success, saved as %s' % f)
    except Exception as e:
        print('ONNX export failure: %s' % e)
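The snippet above comes out of a larger export script, so model, img, opt and y are assumed to exist already. For a standalone test, a minimal setup could look like this (the VGG model and the 112x112 input size here are assumptions based on the files used below):

    import torch
    import torchvision

    # Hypothetical stand-in for the real model; random weights are enough
    # to exercise the export path.
    model = torchvision.models.vgg16(num_classes=2).eval()

    img = torch.zeros(1, 3, 112, 112)  # dummy input; must match the training resolution
    with torch.no_grad():
        y = model(img)  # dry run so the exporter knows the output structure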

2. onnx->tf

    from onnx_tf.backend import prepare
    import onnx

    onnx_model = onnx.load("model_vgg6_sim.onnx")  # load onnx model
    tf_rep = prepare(onnx_model)  # prepare tf representation
    tf_rep.export_graph("model_vgg6_sim.tf")  # export as a TensorFlow SavedModel directory

The environment used here:

tensorflow-cpu==2.6.0
onnx-tf==2.9.0
python==3.8
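Judging by the _sim suffix in the filename, the ONNX graph was presumably run through onnx-simplifier before this step. Assuming the onnxsim package, that would look roughly like:

    import onnx
    from onnxsim import simplify

    model = onnx.load('model_vgg6.onnx')  # output of step 1 (filename assumed)
    model_sim, ok = simplify(model)  # fold constants, strip redundant ops
    assert ok, 'simplified ONNX model failed validation'
    onnx.save(model_sim, 'model_vgg6_sim.onnx')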

3. tf->tflite

    import tensorflow as tf

    # Convert the SavedModel to TFLite
    converter = tf.lite.TFLiteConverter.from_saved_model('model_vgg6_sim.tf')  # path to the SavedModel directory
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops
        tf.lite.OpsSet.SELECT_TF_OPS,  # fall back to TensorFlow ops where needed
    ]
    tflite_model = converter.convert()

    # Save the model
    with open('model.tflite', 'wb') as f:
        f.write(tflite_model)
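Before moving on, it helps to load the freshly written .tflite file and inspect its input details; this also makes the layout problem noted at the end of this post visible:

    # Minimal sanity check on the converted model
    interpreter = tf.lite.Interpreter(model_path='model.tflite')
    interpreter.allocate_tensors()
    print(interpreter.get_input_details()[0]['shape'])  # expect [1, 3, 112, 112], i.e. still NCHW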

4. tflite inference

# -*- coding: utf-8 -*-
# @Time : 2021/10/20 9:15
# @Author : jw hao
# @File : infrence_tflite.py
# @Software: PyCharm

import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import time

import numpy as np
import tensorflow as tf
from PIL import Image
from torchvision import transforms

test_image_dir = 'data/test/'
# model_path = "./model/quantize_frozen_graph.tflite"
model_path = "models/model.tflite"

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
print(str(input_details))
output_details = interpreter.get_output_details()
print(str(output_details))

data_transforms = transforms.Compose([
    transforms.Resize(112),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])


# with tf.Session( ) as sess:
if 1:
    file_list = os.listdir(test_image_dir)

    model_interpreter_time = 0
    start_time = time.time()
    # 遍历文件
    for file in file_list:
        full_path = os.path.join(test_image_dir, file)
        image = Image.open(full_path)
        image = image.resize((112, 112))

        image_np_expanded = data_transforms(image).unsqueeze(0)

      

        # 填装数据
        model_interpreter_start_time = time.time()
        interpreter.set_tensor(input_details[0]['index'], image_np_expanded)

        # 调用模型
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])
        model_interpreter_time += time.time() - model_interpreter_start_time

        # 出来的结果去掉没用的维度
        result = np.squeeze(output_data)
        print('result:{}'.format(result))
        
    used_time = time.time() - start_time
    print('used_time:{}'.format(used_time))
    print('model_interpreter_time:{}'.format(model_interpreter_time))

Known issues:

1. The input of the converted TFLite model is currently NCHW, not NHWC, which may hurt inference time when deploying on CPU; see the workaround sketch below.
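If re-exporting the model in NHWC layout is not an option, one workaround is to transpose the input at inference time. A minimal sketch, reusing the interpreter from the script above (this only adapts the interface; internally the graph still computes in NCHW):

    import numpy as np

    nhwc = np.random.rand(1, 112, 112, 3).astype(np.float32)  # stand-in for a decoded image batch
    nchw = np.transpose(nhwc, (0, 3, 1, 2))  # NHWC -> NCHW
    interpreter.set_tensor(input_details[0]['index'], nchw)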
