pytorch模型转tflite【以EfficientNet-BTS为例】

步骤

使用pytorch转tflite需要经过:pytorch -> onnx -> tensorflow -> tflite

配置环境

# ONNX-TensorFlow:  1.8.0   [pip install onnx-tf==1.8.0]
# ONNX:             1.8.0   [pip install onnx==1.8.0]
## TensorFlow:      2.4.0   [pip install tensorflow==2.4.0]
# tf-nightly:       2.9.0-dev20220223   [pip install tf-nightly]
# PyTorch:          1.8.0   [pip install torch==1.8.0 ]

环境配置上的一些问题:

  • 使用Tensorflow 2.4.0 会在onnx导出pb文件时报错,参考链接。应当使用tf-nightly。issue中推荐使用tf-nightly 2.4.0,测试发现使用最新版本2.9.0也可以解决问题。
  • 使用Pytorch 1.7.0 时会出现Cat等冗余op维度不匹配的问题。导出的onnx模型无法正确inference。使用Pytorch1.8可以规避这个问题
  • onnx与tf的版本对应可以参考链接

Pytorch转onnx

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import onnx
from onnx_tf.backend import prepare
import tensorflow as tf
from onnxsim import simplify
import onnxruntime as ort
import numpy as np


if __name__ == '__main__':
    model = Model()
    # Converting model to ONNX
    for _ in model.modules():
        _.training = False

    test_arr = np.random.randn(1, 3, 480, 640).astype(np.float32)
    sample_input = torch.tensor(test_arr)
    # sample_input = torch.randn(1, 3, 480, 640)
    input_nodes = ['input']
    output_nodes = ['output']

    model(sample_input)

    torch.onnx.export(model, sample_input, "model.onnx", export_params=True, input_names=input_nodes,
                      output_names=output_nodes, opset_version=11)
  • 此处注意opset_version=11,如果设置opset_version=10 / 9 会出现一些op不支持的问题,例如upsample_bilinear。
  • 模型输入大小应当与原始模型输入大小一致,如果想动态适应,可以修改export中dynamic_axis参数
  • Gpu应当设置为不可用,使得全部导出过程在CPU上运行。

onnx模型测试

    model = onnx.load("model.onnx")
    ort_session = ort.InferenceSession('model.onnx')
    onnx_outputs = ort_session.run(None, {'input': test_arr})
    print('Export ONNX!')
  • 如果可以正常通过,证明onnx可以正确导出。
  • 测试时可以和原模型输出对照一下,观察是否存在误差。

onnx模型简化

    onnx_model = onnx.load("model.onnx")
    model_simp, check = simplify(onnx_model)
    assert check, "Simplified ONNX model could not be validated"
  • 模型简化使用的是onnx-simplify工具
  • 模型简化可以去除一些在模型转化过程中产生的冗余Op,例如Concat / SUB 

onnx转tensorflow

    output = prepare(model_simp)
    output.export_graph("tf_model/")
    print('Export tf_model!')
  • onnx转Tensorflow过程中可能会遇到一些Op无法转化的问题,例如interpolate函数,align_corners应当设置为True,然后重新导出onnx。参考链接

tensorflow转tflite

    converter = tf.lite.TFLiteConverter.from_saved_model("tf_model")
    tflite_model = converter.convert()
    open("model.tflite", "wb").write(tflite_model)
    print('Export tf lite model!')
  • 转换时候可能会存在一些问题。安装tf-nightly可以解决。

Onnx和Tflite模型可以通过Netron工具可视化查看。

  • 1
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值