torch与onnxruntime运行时间对比

两者都是在python的GPU下执行的
代码如下:

 pred_mask = torch.randn((1, 3, 512, 512)).cuda()
  model = UNet(num_classes=6).eval().cuda()
   state = {'net': model.state_dict()}
   torch.save(state, 'unet.pth')
   
   checkpoint = torch.load('unet.pth')
   model.load_state_dict(checkpoint['net'])
   
   for i in range(10):
       with torch.no_grad():
           start = time.time()
           out = model(pred_mask)
           print("torchtime", time.time() - start)
    
   input_names = ["input0"]
   output_names = ["output0"]
   onnx_path = 'unet.onnx'
   # pth转成onnx
   # torch.onnx.export(model, pred_mask, onnx_path, verbose=False, input_names=input_names,
   #                   output_names=output_names, opset_version=11)

   # Load the ONNX model
   onnx_model = onnx.load(onnx_path)
   # Check that the IR is well formed
   onnx.checker.check_model(onnx_model)
   onnx.helper.printable_graph(onnx_model.graph)
   output_dir = os.getcwd()
   # onnx spend time
   print(ort.get_device())
   ort_session = ort.InferenceSession(os.path.join(output_dir, onnx_path))
   x_input = pred_mask.data.cpu().numpy()
   input_name = ort_session.get_inputs()[0].name
   output_name = ort_session.get_outputs()[0].name
   for i in range(10):
       start = time.time()
       outputs = ort_session.run([output_name], {input_name: x_input})
       print("orttime", time.time() - start)
models/unet.py
**torchtime** 1.0482349395751953
torchtime 0.01594829559326172
torchtime 0.009942770004272461
torchtime 0.00997304916381836
torchtime 0.006982088088989258
torchtime 0.007978677749633789
torchtime 0.008976459503173828
torchtime **0.00797891616821289**
torchtime 0.009972572326660156
torchtime 0.007978677749633789
GPU
**orttime** 0.8287577629089355
orttime **0.14661574363708496**
orttime 0.13263440132141113
orttime 0.12367129325866699
orttime 0.12967991828918457
orttime 0.12665963172912598
orttime 0.13261842727661133
orttime 0.12366890907287598
orttime 0.13364648818969727
orttime 0.12469029426574707

Process finished with exit code 0

结果却不如预期:
时间上相差很大啊,相差居然有14倍之多。说好的加速呢。或许是c++版的提速??
如果您也遇到同样的困扰可以一起交流下。
在github上也有许多人遇到同样的问题,

https://github.com/microsoft/onnxruntime/issues/2750
https://github.com/microsoft/onnxruntime/issues/2404
https://github.com/microsoft/onnxruntime/issues/2404

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 8
    评论
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值