import tensorrt 之后torch.onnx.export无法导出正常结果

代码复现如下:

import torch
import torch.nn as nn
import torchvision
import os

model = torchvision.models.mobilenet_v2(pretrained=False)
#Define a classification head for 10 classes.
model.classifier[1] = nn.Linear(1280, 10)
  
model.load_state_dict(torch.load('./models/mobilenetv2_base_ckpt')['model_state_dict'])
model.cuda()
model.eval()
dummy_input = torch.randn(32, 3, 224, 224, device='cuda')
input_names = [ "actual_input_1" ]
output_names = [ "output1" ]
torch.onnx.export(
    model,
    dummy_input,
    "./models/mobilenetv2_base.onnx",
    verbose=False,
    opset_version=13,
    input_names=input_names,
    output_names=output_names,
    do_constant_folding = False)
print(' export onnx done ~ ')

如上代码,此时没有import tensorrt ,运行成功显示如下结果

D:\conda\envs\tensorrt\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
D:\conda\envs\tensorrt\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
  warnings.warn(msg)
================ Diagnostic Run torch.onnx.export version 2.0.0 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

 export onnx done ~

现在就多加1行 import tensorrt as trt,没有报错,也没有任何输出,程序就结束了,不知道为什么。

D:\conda\envs\tensorrt\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
 warnings.warn(
D:\conda\envs\tensorrt\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
 warnings.warn(msg)

(tensorrt) E:\TensorRT\TensorRT-main\mycode\ptq_static>

conda环境主要如下:

pytorch                   2.0.0           py3.8_cuda11.7_cudnn8_0    pytorch
pytorch-cuda              11.7                 h16d0643_3    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pytorch-quantization      2.1.2                    pypi_0    pypi
requests                  2.29.0           py38haa95532_0
setuptools                66.0.0           py38haa95532_0
sphinx-glpi-theme         0.3                      pypi_0    pypi
sqlite                    3.41.2               h2bbff1b_0
stack_data                0.2.0              pyhd3eb1b0_0
tabulate                  0.9.0                    pypi_0    pypi
tbb                       2021.8.0             h59b6b97_0
tensorrt                  8.5.1.7                  pypi_0    pypi
tk                        8.6.12               h2bbff1b_0
torchaudio                2.0.0                    pypi_0    pypi
torchvision               0.15.0                   pypi_0    pypi
tornado                   6.2              py38h2bbff1b_0
tqdm                      4.65.0                   pypi_0    pypi
traitlets                 5.7.1            py38haa95532_0
typing_extensions         4.5.0            py38haa95532_0
vs2015_runtime            14.27.29016          h5e58377_2
wcwidth                   0.2.6                    pypi_0    pypi
wget                      3.2                      pypi_0    pypi
win_inet_pton             1.1.0            py38haa95532_0
zstd                      1.5.5                hd43e919_0
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值