【Mxnet2onnx模型报错】GEMM: Dimension mismatch
1. 问题
- 由于项目的需求,需要将mobilefacenet(特征维度为512)训练好的Mxnet模型转换为onnx模型进行推理,结果得到如下错误:
2023-05-24 17:48:29.121959761 [W:onnxruntime:, graph.cc:84 MergeShapeInfo] Error merging shape info for output. 'fc1' source:{1,512,1,512} target:{1,512}. Falling back to lenient merge
2023-05-24 17:48:29.168708774 [E:onnxruntime:, sequential_executor.cc:339 Execute] Non-zero status code returned while running Gemm node. Name:'' Status Message: GEMM: Dimension mismatch, W: {512,512} K: 1 N:512
Traceback (most recent call last):
File "mxnet_to_onnx.py", line 155, in
onnx_inferred_demo()
File "mxnet_to_onnx.py", line 52, in onnx_inferred_demo
out = ort_session.run([outputs], input_feed={input_name: input_blob})
File "/usr/local/python3/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 188, in run
return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Gemm node. Name:'' Status Message: GEMM: Dimension mismatch, W: {512,512} K: 1 N:512
2. 查看网络结构
- 查看onnx模型结构的
fc1
层,结果如下图所示:
看对应的网络结构,好像也没有啥问题~~
3. 解决方法
尝试多种解决方法,本来使用的是mxnet==1.9.0,最后将mxnet包降低版本问题就顺利解决了(版本兼容性太差了~~)。具体安装版本命令如下:
pip3 install mxnet==1.8.0 -i https://pypi.tuna.tsinghua.edu.cn/simple