训练完pytorch模型后,将其转换成onnx模型:
torch.onnx.export(model, # model being run
x, # model input (or a tuple for multiple inputs)
"resnet18_float.onnx", # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['input'], # the model's input names
output_names = ['output'], # the model's output names
dynamic_axes={'input' : {0 : '1'}, # variable lenght axes
'output' : {0 : '1'}}
成功转换。
但是如果网络结构中有F.avg_pool2d()即平均池化层时可能会遇到一些问题。特别是这样使用:
out = F.avg_pool2d(out, 4)
这样转换的onnx模型的平均池化层中strides是空的(可以用netron查看)。
如果用onnxruntime进行推理,则会报错:
Error in Node: : Attribute ‘strides’ is expected to have field ‘ints’
解决办法:指定stride
out = F.avg_pool2d(out, 4, stride=4)
似乎最新版本的pytorch已经解决这个问题。