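Background: a while ago I wanted to convert an ONNX model from opset 12 to opset 11, but I was in a rush and couldn't find out how. Recently I found that the official ONNX site has an example for exactly this (big love for the ONNX site), so I'm sharing it with anyone who needs it and hasn't found it yet.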
1. Converting the opset version of an ONNX model
Official example:
import onnx
from onnx import version_converter, helper

# Preprocessing: load the model to be converted.
model_path = "path/to/the/model.onnx"
original_model = onnx.load(model_path)

print(f"The model before conversion:\n{original_model}")

# A full list of supported adapters can be found here:
# https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21

# Apply the version conversion on the original model
target_version = 11  # stands in for <int target_version>: set the opset you need
converted_model = version_converter.convert_version(original_model, target_version)

print(f"The model after conversion:\n{converted_model}")
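The official example comes from this page on GitHub:
onnx/docs/PythonAPIOverview.md at main · onnx/onnx (github.com): https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md#converting-version-of-an-onnx-model-within-default-domain-aionnx
Someone has also mirrored it to gitee, for readers who can't open GitHub.
Below is a runnable version of the example that converts a local model to opset 11 and saves the result: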
import onnx
from onnx import version_converter

# Preprocessing: load the model to be converted.
model_path = r"./demo.onnx"
original_model = onnx.load(model_path)
print(f"The model before conversion:\n{original_model}")

# A full list of supported adapters can be found here:
# https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21

# Apply the version conversion on the original model
converted_model = version_converter.convert_version(original_model, 11)
print(f"The model after conversion:\n{converted_model}")

# Save next to the original: "./demo.onnx" -> "./demo_opset11.onnx"
save_model = model_path[:-5] + "_opset11.onnx"
onnx.save(converted_model, save_model)
2. Fixing the dynamic input size of an ONNX model
def change_dynamic_input_shape(model_path, shape_list: list):
    """
    Change a dynamic input size into a fixed size.

    Args:
        model_path: onnx model path
        shape_list: [1, 3, ...], one fixed value per leading dimension of the first input

    Returns:
        None. The fixed model is saved as <model name>_fixed.onnx next to the original.
    """
    import os

    import onnx

    model_path = os.path.abspath(model_path)
    output_path = model_path[:-5] + "_fixed.onnx"
    model = onnx.load(model_path)
    # print(onnx.helper.printable_graph(model.graph))

    inputs = model.graph.input  # inputs is a list, so models with multiple inputs can be handled too
    # look_input = inputs[0].type.tensor_type.shape.dim
    # print(look_input)
    # print(type(look_input))
    # inputs[0].type.tensor_type.shape.dim[0].dim_value = 1

    # Overwrite the first len(shape_list) dims of the first input with fixed values
    for idx, i_e in enumerate(shape_list):
        inputs[0].type.tensor_type.shape.dim[idx].dim_value = i_e

    # print(onnx.helper.printable_graph(model.graph))
    onnx.save(model, output_path)

if __name__ == "__main__":
    model_path = "./demo.onnx"
    shape_list = [1]  # fix only the first (batch) dimension to 1
    change_dynamic_input_shape(model_path, shape_list)