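ONNX series:
This time, we again use a very simple PyTorch model converted to ONNX as the example. The source code is adapted from: https://github.com/TrojanXu/onnxparser-trt-plugin-sample
Only the core code is shown here; this is not a complete project.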
1. PyTorch model definition
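import torch
import torch.nn.functional as F

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, input, grid):
        input += 1  # in-place add (also modifies the caller's tensor); becomes Constant_0 + Add_1 in the exported graph
        # becomes GridSampler_2 in the exported graph
        return F.grid_sample(input, grid, mode='bilinear', padding_mode='reflection', align_corners=True)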
2. graph.input, graph.output, graph.node
Load the model with onnx, then print the graph's input, output, and node fields.
import onnx

model = onnx.load(onnx_model_file)  # onnx_model_file: path to the exported .onnx file
print(model.graph.input)
>>>
[name: "input"
type {
  tensor_type {
    elem_type: 10
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 4
      }
      dim {
        dim_value: 4
      }
    }
  }
}
, name: "grid"
type {
  tensor_type {
    elem_type: 10
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_value: 4
      }
      dim {
        dim_value: 4
      }
      dim {
        dim_value: 2
      }
    }
  }
}
]
print(model.graph.output)
>>>
[name: "output"
type {
  tensor_type {
    elem_type: 10
    shape {
      dim {
        dim_value: 4
      }
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 4
      }
      dim {
        dim_value: 4
      }
    }
  }
}
]
print(len(model.graph.node), model.graph.node)
>>>
3
[output: "2"
name: "Constant_0"
op_type: "Constant"
attribute {
  name: "value"
  t {
    data_type: 10
    raw_data: "\000<"
  }
  type: TENSOR
}
, input: "input"
input: "2"
output: "3"
name: "Add_1"
op_type: "Add"
, input: "3"
input: "grid"
output: "output"
name: "GridSampler_2"
op_type: "GridSampler"
attribute {
  name: "aligncorners"
  i: 1
  type: INT
}
attribute {
  name: "interpolationmode"
  i: 0
  type: INT
}
attribute {
  name: "paddingmode"
  i: 2
  type: INT
}
]
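Two values in the node dump are encoded. The Constant's `raw_data` holds the little-endian bytes of a float16 scalar (data_type 10). GridSampler is not a standard ONNX operator (the standard one is GridSample, added in opset 16, with string attributes); it is presumably emitted by a custom export in the referenced repository so that TensorRT can map it to a plugin. Its integer attributes appear to reuse PyTorch's internal codes for `grid_sample` (bilinear=0, nearest=1, bicubic=2; zeros=0, border=1, reflection=2), which matches `mode='bilinear'` and `padding_mode='reflection'` in the model. A minimal standard-library sketch for decoding them; the dictionary and function names are hypothetical:

```python
import struct

# PyTorch's integer codes for F.grid_sample options (assumed to be
# forwarded unchanged into the GridSampler node's attributes).
INTERPOLATION_MODES = {0: "bilinear", 1: "nearest", 2: "bicubic"}
PADDING_MODES = {0: "zeros", 1: "border", 2: "reflection"}


def decode_fp16_scalar(raw_data: bytes) -> float:
    """Decode a float16 scalar stored little-endian in TensorProto.raw_data."""
    (value,) = struct.unpack("<e", raw_data)  # "e" = IEEE 754 half precision
    return value


# raw_data "\000<" in the dump is the two bytes 0x00, 0x3C.
const_value = decode_fp16_scalar(b"\x00<")
attrs = {"aligncorners": 1, "interpolationmode": 0, "paddingmode": 2}

print(const_value)                                      # 1.0, the "+= 1" in forward()
print(bool(attrs["aligncorners"]))                      # True
print(INTERPOLATION_MODES[attrs["interpolationmode"]])  # bilinear
print(PADDING_MODES[attrs["paddingmode"]])              # reflection
```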
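From this output we can see:
- input: the tensors fed to the model at inference time. For networks with layers such as conv, the pretrained weights may also appear here: they are stored in graph.initializer, and depending on the exporter settings they are additionally listed in graph.input. This model has no weights, so only input and grid appear.
- output: the model's outputs.
- node: the operations of the graph, connected to each other through the names of their input and output tensors.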