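# NOTE: RKNN currently only supports opset <= 12, so the model's original
# opset and IR version are carried over unchanged to the saved model.
# (Original note: rknn目前只支持opset<=12,所以把版本一块带上)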
import onnx
from onnx.shape_inference import infer_shapes
import tflite2onnx
def onnx_as(input_size, onnx_path, onnx_save):
    """Change the spatial input size of an ONNX model and re-run shape inference.

    Dims 2 and 3 of the first graph input are treated as height and width
    (an (N, C, H, W) layout is assumed -- confirm for your model) and are
    replaced by ``input_size``. Every graph output that has at least 4 dims
    with static values in dims 2/3 is rescaled by the same factor. For
    example, a head that was 80x80 for a 640x640 input becomes 52x52 for
    416x416.

    Shape inference is then re-run on the loaded model itself. Its opset
    imports and IR version are therefore preserved unchanged. The author's
    target (RKNN) only supports opset <= 12.

    :param input_size: ``(height, width)`` of the new input; written to dims
        2 and 3 of the first graph input.
    :param onnx_path: path of the ONNX model to modify.
    :param onnx_save: path the modified model is written to.
    :return: the modified ``onnx.ModelProto`` (also saved to ``onnx_save``).
    :raises ValueError: if the first input has fewer than 4 dims, or its
        dims 2/3 are not static positive values (the output scaling factor
        cannot be computed then).
    """
    new_h, new_w = int(input_size[0]), int(input_size[1])

    model = onnx.load_model(onnx_path)
    print(model.opset_import, model.ir_version)
    graph = model.graph

    # Rewrite the spatial dims of the first input.
    in_dims = graph.input[0].type.tensor_type.shape.dim
    if len(in_dims) < 4:
        raise ValueError(f"expected a 4-D first input, got rank {len(in_dims)}")
    orig_h, orig_w = in_dims[2].dim_value, in_dims[3].dim_value
    if orig_h <= 0 or orig_w <= 0:
        # dim_value is 0 for symbolic (dim_param) dims.
        raise ValueError(
            f"input dims 2/3 must be static to rescale outputs, got {orig_h}x{orig_w}"
        )
    in_dims[2].dim_value = new_h
    in_dims[3].dim_value = new_w

    # Rescale each output's spatial dims by the same factor as the input.
    for output in graph.output:
        out_dims = output.type.tensor_type.shape.dim
        if len(out_dims) < 4 or out_dims[2].dim_value <= 0 or out_dims[3].dim_value <= 0:
            # Not an output with static dims 2/3; its declared shape is left as-is.
            print(f"skip output {output.name}: no static dims 2/3 to rescale")
            continue
        print(out_dims[3].dim_value, out_dims[2].dim_value, end=" ")
        # Exact integer form of new * out / orig (no float rounding).
        out_dims[3].dim_value = new_w * out_dims[3].dim_value // orig_w
        out_dims[2].dim_value = new_h * out_dims[2].dim_value // orig_h
        print(out_dims[3].dim_value, out_dims[2].dim_value)

    # value_info still holds intermediate shapes computed for the old input
    # size. Drop them so inference recomputes every tensor from the new input.
    # Inferring on the loaded model (rather than a rebuilt one from
    # make_model) keeps its original opsets, so inference uses the same op
    # semantics the model was exported with.
    del graph.value_info[:]
    onnx_model = onnx.shape_inference.infer_shapes(model)

    # Validate the model that is actually written out.
    onnx.checker.check_model(onnx_model)
    onnx.save_model(onnx_model, onnx_save)
    return onnx_model
if __name__ == '__main__':
    import argparse

    # Defaults reproduce the original hard-coded run; override on the command line.
    parser = argparse.ArgumentParser(
        description="Change the input H/W of an ONNX model and re-infer its shapes."
    )
    parser.add_argument(
        "--onnx-path",
        default=r"/home/wyf/3gi/project/rockx_master/rv1106_auto_lane/DMS/pose/yolox_ss_pose.onnx",
        help="ONNX model to modify",
    )
    parser.add_argument(
        "--onnx-save", default=r"model_as.onnx", help="where to write the modified model"
    )
    parser.add_argument("--height", type=int, default=416, help="new input height")
    parser.add_argument("--width", type=int, default=416, help="new input width")
    args = parser.parse_args()

    onnx_as([args.height, args.width], args.onnx_path, args.onnx_save)
    # Example input size -> output feature-map sizes at strides 8 / 16 / 32:
    #   896x512 -> 112x64, 56x32, 28x16
    #   512x512 -> 64x64, 32x32, 16x16