在用netron查看模型时,希望看到各个节点的shape,可以执行以下代码。
1、依赖包
pip install onnx
pip install onnx_graphsurgeon --index-url https://pypi.ngc.nvidia.com
2、代码
简单点的代码:
import onnx
onnx_graph = onnx.load(input_onnx_file_path)
estimated_graph = onnx.shape_inference.infer_shapes(onnx_graph)
onnx.save(estimated_graph, output_onnx_file_path)
复杂代码:
import onnx
import onnx_graphsurgeon as gs
onnx_graph = onnx.load(input_onnx_file_path)
graph = gs.import_onnx(onnx_graph)
graph.cleanup().toposort() #从图形中删除未使用的节点和张量,并对图形进行拓扑排序
# Shape Estimation
estimated_graph = None
try:
estimated_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
except:
estimated_graph = gs.export_onnx(graph)
onnx.save(estimated_graph, output_onnx_file_path)
效果图:
原来 | 新的 |
---|---|