一. 背景介绍
当我们的测试代码有很多依赖,或者训练测试代码在一起时,如何快速地导出onnx?
二. 实现
利用测试代码中原始的数据处理,在模型测试的时候直接导出onnx,并对比torch和onnx的推理结果。
tips:设置一个export_onnx参数,直接导出onnx。
实现代码如下,一个例子:
with torch.no_grad():
export_onnx = True
if export_onnx:
input_img = img_mix.to(device)
save_path = "G.onnx"
torch.onnx.export(net_G,
input_img,
save_path,
export_params=True,
opset_version=11,
input_names=['input'],
output_names=['output'])
print("export onnx success")
G_pred = net_G(img_mix.to(device))[:, 0:3, :, :]
if export_onnx:
import onnxruntime as rt
sess = rt.InferenceSession("G.onnx")
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
pred_onnx = sess.run([output_name], {input_name:np.array(img_mix)})
# 测试torch与onnx之间的推理误差
print(abs(pred_onnx[0][:,0:3,:,:]-np.array(G_pred.cpu().detach())).max())