Code to reproduce:
import torch
import torch.nn as nn
import torchvision
import os

model = torchvision.models.mobilenet_v2(pretrained=False)
# Define a classification head for 10 classes.
model.classifier[1] = nn.Linear(1280, 10)
# Load the trained weights from the checkpoint.
model.load_state_dict(torch.load('./models/mobilenetv2_base_ckpt')['model_state_dict'])
model.cuda()
model.eval()

# Dummy batch of 32 RGB 224x224 images, used only to trace the model.
dummy_input = torch.randn(32, 3, 224, 224, device='cuda')
input_names = ["actual_input_1"]
output_names = ["output1"]
torch.onnx.export(
    model,
    dummy_input,
    "./models/mobilenetv2_base.onnx",
    verbose=False,
    opset_version=13,
    input_names=input_names,
    output_names=output_names,
    do_constant_folding=False)
print(' export onnx done ~ ')
With the code above, which does not import tensorrt, the script runs successfully and prints the following:
D:\conda\envs\tensorrt\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
D:\conda\envs\tensorrt\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
warnings.warn(msg)
================ Diagnostic Run torch.onnx.export version 2.0.0 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
export onnx done ~
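Now I add just one line, import tensorrt as trt. There is no error and no further output (the export diagnostics and "export onnx done ~" never appear); the program simply exits, and I don't know why: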
D:\conda\envs\tensorrt\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
D:\conda\envs\tensorrt\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
warnings.warn(msg)
(tensorrt) E:\TensorRT\TensorRT-main\mycode\ptq_static>
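When a Python script exits without any traceback, the interpreter has usually died inside native code (for example an access violation in a DLL) rather than raising a Python exception. One way to check this is sketched below: rerun the repro in a child process with Python's faulthandler enabled (-X faulthandler) and decode the exit code. The script name export_onnx.py is hypothetical, and the table of Windows status codes is a small hand-picked subset, not an exhaustive list.

```python
import subprocess
import sys

# A few well-known Windows NTSTATUS values that can appear as process exit codes.
WINDOWS_STATUS = {
    0xC0000005: "STATUS_ACCESS_VIOLATION (native crash)",
    0xC0000135: "STATUS_DLL_NOT_FOUND",
    0xC0000409: "STATUS_STACK_BUFFER_OVERRUN",
}


def run_and_report(script_args):
    """Run Python with faulthandler enabled and return the child's exit code.

    faulthandler makes the child dump its Python stack to stderr if it dies
    from a fatal error (segfault / access violation) instead of an exception.
    """
    cmd = [sys.executable, "-X", "faulthandler", *script_args]
    proc = subprocess.run(cmd)
    code = proc.returncode
    # Windows exit codes are unsigned 32-bit values; masking maps a signed
    # representation (e.g. -1073741819) onto the same key (0xC0000005).
    status = code & 0xFFFFFFFF
    if code == 0:
        print("exit code 0: process finished normally")
    elif status in WINDOWS_STATUS:
        print(f"exit code {code} (0x{status:08X}): {WINDOWS_STATUS[status]}")
    elif code < 0:
        # On POSIX a negative return code means the child was killed by a signal.
        print(f"killed by signal {-code}")
    else:
        print(f"exit code {code} (0x{status:08X})")
    return code


if __name__ == "__main__":
    # Hypothetical default: the repro script from this post saved as export_onnx.py.
    run_and_report(sys.argv[1:] or ["export_onnx.py"])
```

If this reports 0xC0000005 and faulthandler prints a stack pointing at the line where the process died, the failure is in native code rather than in Python. A clash between the CUDA/cuDNN DLLs loaded by tensorrt and those loaded by torch is one plausible cause worth checking, but this sketch only locates the crash; it does not prove the cause. Adding import faulthandler; faulthandler.enable() at the top of the repro script has the same effect without the wrapper.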
Main packages in the conda environment:
# Name                Version       Build                    Channel
pytorch               2.0.0         py3.8_cuda11.7_cudnn8_0  pytorch
pytorch-cuda          11.7          h16d0643_3               pytorch
pytorch-mutex         1.0           cuda                     pytorch
pytorch-quantization  2.1.2         pypi_0                   pypi
requests              2.29.0        py38haa95532_0
setuptools            66.0.0        py38haa95532_0
sphinx-glpi-theme     0.3           pypi_0                   pypi
sqlite                3.41.2        h2bbff1b_0
stack_data            0.2.0         pyhd3eb1b0_0
tabulate              0.9.0         pypi_0                   pypi
tbb                   2021.8.0      h59b6b97_0
tensorrt              8.5.1.7       pypi_0                   pypi
tk                    8.6.12        h2bbff1b_0
torchaudio            2.0.0         pypi_0                   pypi
torchvision           0.15.0        pypi_0                   pypi
tornado               6.2           py38h2bbff1b_0
tqdm                  4.65.0        pypi_0                   pypi
traitlets             5.7.1         py38haa95532_0
typing_extensions     4.5.0         py38haa95532_0
vs2015_runtime        14.27.29016   h5e58377_2
wcwidth               0.2.6         pypi_0                   pypi
wget                  3.2           pypi_0                   pypi
win_inet_pton         1.1.0         py38haa95532_0
zstd                  1.5.5         hd43e919_0
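Because the listing above is only part of conda list, a small helper like the one below could report the versions most relevant to this crash in one place. It is a sketch: the package list is my own choice (onnx is included on the assumption that it may be installed), and distribution names can differ from import names. It reads installed metadata through importlib.metadata, which does not import the packages, so it cannot itself trigger the crash.

```python
from importlib import metadata
import platform
import sys

# Distribution names chosen for this report (an assumption, adjust as needed).
PACKAGES = ["torch", "torchvision", "torchaudio", "tensorrt",
            "pytorch-quantization", "onnx"]


def collect_versions(packages=PACKAGES):
    """Return {name: version} read from installed package metadata.

    importlib.metadata only reads dist-info/egg-info files, so the packages
    are never imported and their CUDA DLLs are never loaded.
    """
    info = {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_versions().items():
        print(f"{key:22s}{value}")
```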