# Single input (一个输入):
# Single-input ONNX export: trace `model` on one random dummy tensor and write
# the resulting graph to the path held in `onnx_save_name`.
# `model` and `onnx_save_name` are not defined in this snippet; they must exist
# before these lines run.
# NOTE(review): the dummy shape (1, 3, 256, 512) is assumed to match the input
# the model expects — confirm against the model definition.
input_tensor = torch.randn(1, 3, 256, 512)
print("Exporting to ONNX: ", onnx_save_name)
# NOTE(review): with the classic (TorchScript-based) exporter this return value
# is None; the name is kept in case later code refers to it.
torch_onnx_out = torch.onnx.export(
    model,
    input_tensor,
    onnx_save_name,
    export_params=True,  # store the model's parameters inside the .onnx file
    verbose=True,  # print a description of the exported graph to stdout
    input_names=["label"],  # name assigned to the graph's single input node
    output_names=["synthesized"],  # name assigned to the graph's output node
    opset_version=11,
)
# Multiple inputs (多个输入):
# Multi-input ONNX export: when a model takes several inputs, pass them to
# torch.onnx.export as a tuple; `input_names` is matched to the tuple's
# elements in order ("label" -> input_tensor, "mask" -> mask_tensor).
# `model` and `onnx_save_name` are not defined in this snippet; they must exist
# before these lines run.
# NOTE(review): both dummy tensors use shape (1, 3, 256, 512); confirm this is
# what the model expects for each input.
# The two tensors are drawn in order: input_tensor first, then mask_tensor.
input_tensor, mask_tensor = (torch.randn(1, 3, 256, 512) for _ in range(2))
print("Exporting to ONNX: ", onnx_save_name)
# NOTE(review): with the classic (TorchScript-based) exporter this return value
# is None; the name is kept in case later code refers to it.
torch_onnx_out = torch.onnx.export(
    model,
    (input_tensor, mask_tensor),
    onnx_save_name,
    export_params=True,  # store the model's parameters inside the .onnx file
    verbose=True,  # print a description of the exported graph to stdout
    input_names=["label", "mask"],  # one name per element of the input tuple
    output_names=["synthesized"],  # name assigned to the graph's output node
    opset_version=11,
)