PyTorch 模型转 TFLite
参考文档:
1.https://blog.csdn.net/computerme/article/details/84144930
2.https://blog.csdn.net/qq_40600539/article/details/123142541
配置环境:
# tensorflow 2.4.0
# onnx 1.8.0
# onnx-tensorflow 1.8.0 [onnx-tf]
# tf-nightly 2.9.0
# pytorch 1.8.0
# onnx-simplifier [onnxsim](版本未注明)
# onnxruntime(版本未注明)
参考代码:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf
from onnxsim import simplify
import onnxruntime as ort
import numpy as np
import torch.nn as nn
import torch
class Model(nn.Module):
    """Small two-stage convolutional feature extractor used as the export example.

    Each stage is Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2, 2):
      * stage 1: 3 -> 32 channels, 3x3 kernel, stride 2, no padding
      * stage 2: 32 -> 64 channels, 3x3 kernel, stride 1, no padding
    """

    def __init__(self):
        super().__init__()
        stage_one = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        stage_two = nn.Sequential(
            nn.Conv2d(32, 64, 3, 1, groups=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        # Both stages live under one Sequential so the parameter names are
        # feature.0.* and feature.1.*.
        self.feature = nn.Sequential(stage_one, stage_two)
        self.init_weights()

    def forward(self, x):
        """Apply both stages to ``x`` (3 input channels) and return the feature map."""
        return self.feature(x)

    def init_weights(self):
        """Re-initialise parameters in place.

        Conv2d weights use Kaiming-normal init (fan_out, relu) and their biases
        are zeroed. BatchNorm2d scales are set to 1 and shifts to 0.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
if __name__ == '__main__':
    # End-to-end conversion:
    # PyTorch -> ONNX -> simplified ONNX -> TF SavedModel -> TFLite.
    onnx_path = "model.onnx"
    saved_model_dir = "tf_model"
    tflite_path = "model.tflite"

    model = Model()
    # Put every submodule (BatchNorm in particular) into inference mode before
    # export, so the traced graph uses running statistics rather than batch stats.
    model.eval()

    # One fixed random input is shared by every backend so their outputs can be compared.
    test_arr = np.random.randn(1, 3, 480, 640).astype(np.float32)
    sample_input = torch.tensor(test_arr)
    input_nodes = ['input']
    output_nodes = ['output']

    # Reference output from PyTorch. No autograd graph is needed for inference.
    with torch.no_grad():
        torch_output = model(sample_input).numpy()

    # Converting model to ONNX
    torch.onnx.export(model, sample_input, onnx_path, export_params=True,
                      input_names=input_nodes, output_names=output_nodes,
                      opset_version=11)

    # Verify the exported graph: ONNX Runtime must reproduce the PyTorch output.
    # The CPU provider is explicit because the script disables CUDA above.
    ort_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    onnx_outputs = ort_session.run(None, {input_nodes[0]: test_arr})
    np.testing.assert_allclose(torch_output, onnx_outputs[0], rtol=1e-3, atol=1e-4)
    print('Export ONNX!')

    onnx_model = onnx.load(onnx_path)
    model_simp, check = simplify(onnx_model)
    # Raise explicitly instead of using `assert`, which is stripped under `python -O`.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")

    tf_rep = prepare(model_simp)
    tf_rep.export_graph(saved_model_dir)
    print('Export tf_model!')

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    # The context manager guarantees the flatbuffer is flushed and the handle closed.
    with open(tflite_path, "wb") as f:
        f.write(tflite_model)
    print('Export tf lite model!')