PyTorch 转 ONNX
import torch
from unet import UNet
import onnx
from onnx import shape_inference
weight_path = 'checkpoints/checkpoint_epoch5.pth'
onnx_path = 'best.onnx'

model = UNet(n_channels=3, n_classes=12)
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; the model itself is never moved off the CPU for export.
state_dict = torch.load(weight_path, map_location='cpu')
# The checkpoint carries an extra 'mask_values' entry that is not a model
# parameter; the (strict) load_state_dict below would reject it as an
# unexpected key, so drop it -- but tolerate checkpoints that lack it.
state_dict.pop('mask_values', None)
model.load_state_dict(state_dict)
model.eval()  # export with layers in inference mode

# Dummy network input used to trace the graph: (batch, channels, height, width),
# channels matching n_channels=3 above. Without dynamic_axes the exported graph
# is fixed to exactly this shape.
dummy_input = torch.randn(1, 3, 224, 224)
input_names = ["input"]
output_names = ["output"]
torch.onnx.export(model, dummy_input, onnx_path,
                  input_names=input_names, output_names=output_names,
                  opset_version=11)

# Optional: run ONNX shape inference and store the inferred shapes of the
# intermediate tensors in the file (useful when inspecting the graph).
# model = onnx.load_model(onnx_path)
# onnx.save(onnx.shape_inference.infer_shapes(model), onnx_path)
Keras 转 ONNX
import keras2onnx
import onnx
from keras.models import load_model
from onnx import shape_inference
model_path = 'best.h5'
save_file = 'best.onnx'
# Static batch size written into every graph input/output (dynamic -> static graph).
batch_size = 1

model = load_model(model_path)
# The Keras model name is reused as the name of the ONNX graph.
onnx_model = keras2onnx.convert_keras(model, model.name)

# Convert the dynamic graph into a static one: pin the leading (batch) dimension
# of every real graph input and every graph output to `batch_size`.
# Presumably keras2onnx exports that dimension as a symbolic dim_param (inspect
# onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param to confirm).
# dim_value and dim_param form a protobuf oneof, so assigning dim_value also
# clears the symbolic name.
graph = onnx_model.graph
# Older ONNX IR versions also list weights (initializers) under graph.input;
# those are not batched data and must keep their real first dimension.
initializer_names = {init.name for init in graph.initializer}
for value_info in list(graph.input) + list(graph.output):
    if value_info.name in initializer_names:
        continue
    dims = value_info.type.tensor_type.shape.dim
    if dims:  # scalars / tensors without a recorded shape have no dim[0]
        dims[0].dim_value = batch_size

# Propagate the now-static shapes through the graph and write the final model
# once (no need to save and reload the intermediate file).
onnx.save(onnx.shape_inference.infer_shapes(onnx_model), save_file)
TensorFlow 转 ONNX
# Convert a TensorFlow SavedModel directory to ONNX (opset 11), logging verbosely.
python -m tf2onnx.convert \
    --saved-model ./checkpoints/yolov4.tf \
    --output model.onnx \
    --opset 11 \
    --verbose