Both were run in Python on the GPU. The code is as follows:
import os
import time

import onnx
import onnxruntime as ort
import torch

from models.unet import UNet

# Dummy input on the GPU
pred_mask = torch.randn((1, 3, 512, 512)).cuda()

# Build the model and round-trip the weights through a checkpoint
model = UNet(num_classes=6).eval().cuda()
state = {'net': model.state_dict()}
torch.save(state, 'unet.pth')
checkpoint = torch.load('unet.pth')
model.load_state_dict(checkpoint['net'])

# Time PyTorch inference
for i in range(10):
    with torch.no_grad():
        start = time.time()
        out = model(pred_mask)
        print("torchtime", time.time() - start)

input_names = ["input0"]
output_names = ["output0"]
onnx_path = 'unet.onnx'

# Convert the .pth model to ONNX (already exported, hence commented out)
# torch.onnx.export(model, pred_mask, onnx_path, verbose=False, input_names=input_names,
#                   output_names=output_names, opset_version=11)

# Load the ONNX model
onnx_model = onnx.load(onnx_path)
# Check that the IR is well formed
onnx.checker.check_model(onnx_model)
onnx.helper.printable_graph(onnx_model.graph)

output_dir = os.getcwd()

# Time ONNX Runtime inference
print(ort.get_device())
ort_session = ort.InferenceSession(os.path.join(output_dir, onnx_path))
x_input = pred_mask.data.cpu().numpy()
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
for i in range(10):
    start = time.time()
    outputs = ort_session.run([output_name], {input_name: x_input})
    print("orttime", time.time() - start)
The console output (UNet itself is defined in models/unet.py):
torchtime 1.0482349395751953
torchtime 0.01594829559326172
torchtime 0.009942770004272461
torchtime 0.00997304916381836
torchtime 0.006982088088989258
torchtime 0.007978677749633789
torchtime 0.008976459503173828
torchtime 0.00797891616821289
torchtime 0.009972572326660156
torchtime 0.007978677749633789
GPU
orttime 0.8287577629089355
orttime 0.14661574363708496
orttime 0.13263440132141113
orttime 0.12367129325866699
orttime 0.12967991828918457
orttime 0.12665963172912598
orttime 0.13261842727661133
orttime 0.12366890907287598
orttime 0.13364648818969727
orttime 0.12469029426574707
Process finished with exit code 0
The results were not as expected:
The time gap is huge: ONNX Runtime comes out more than 14 times slower. So much for the promised speedup. Maybe the acceleration only shows up in the C++ version?
If you have run into the same problem, feel free to get in touch and compare notes.
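Before writing onnxruntime off, two things are worth ruling out. ort.get_device() only reports what the installed package was built with, not what a given session actually uses, and passing numpy arrays to run() copies the input to the device and the output back on every call. A hedged sketch of both checks, using the providers argument and io_binding API documented for the onnxruntime-gpu Python package (verify them against your installed version):

import os
import numpy as np
import onnxruntime as ort

# Request the CUDA execution provider explicitly, with a CPU fallback
ort_session = ort.InferenceSession(
    os.path.join(os.getcwd(), 'unet.onnx'),
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
print(ort_session.get_providers())  # CUDAExecutionProvider should be listed first

x_input = np.random.randn(1, 3, 512, 512).astype(np.float32)
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name

# Explicit I/O binding: ORT copies the CPU input to the CUDA device per run,
# and the output stays on-device until explicitly copied back
binding = ort_session.io_binding()
binding.bind_cpu_input(input_name, x_input)
binding.bind_output(output_name)
ort_session.run_with_iobinding(binding)
outputs = binding.copy_outputs_to_cpu()[0]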
Plenty of people on GitHub have run into the same issue:
https://github.com/microsoft/onnxruntime/issues/2750
https://github.com/microsoft/onnxruntime/issues/2404