给一段python代码 可以查看.onnx文件的所有节点。
import onnx
def print_graph_nodes(model_path):
# 加载 ONNX 模型
model = onnx.load(model_path)
# 遍历所有图节点并打印节点信息
for node in model.graph.node:
node_type = node.op_type
node_name = node.name
print(f'Node Type: {node_type}, Node Name: {node_name}')
if __name__ == '__main__':
onnx_model_file = 'path/to/your/model.onnx'
print_graph_nodes(onnx_model_file)
给一段python代码 可以查看.trt文件的所有节点
import tensorrt as trt
def print_network_nodes(trt_engine_path):
# 加载TensorRT引擎
with open(trt_engine_path, 'rb') as f, trt.Runtime(trt.Logger()) as runtime:
engine = runtime.deserialize_cuda_engine(f.read())
# 遍历所有网络层并打印节点信息
for layer in engine:
layer_type = layer.type
layer_name = layer.name
print(f'Layer Type: {layer_type}, Layer Name: {layer_name}')
if __name__ == '__main__':
trt_engine_file = 'path/to/your/model.trt'
print_network_nodes(trt_engine_file)