参考大神
本项目演示demo地址
1.在pytorch中运行letnet5
开始之前请确保正确的安装了pytorch依赖
1.1 克隆本博客演示代码,并进入letnetPy文件
git clone https://github.com/python-faker/csdn_example
cd csdn_example/tensorrtExample/lenetPy
1.2 运行lenet5.py文件通过pytorch生成lenet5.pth权重文件,具体的lenet5定义及细节都在lenet5.py文件中
python lenet5.py
可以看到如下输出
cuda device count: 1
input: torch.Size([1, 1, 32, 32])
conv1 torch.Size([1, 6, 28, 28])
pool1: torch.Size([1, 6, 14, 14])
conv2 torch.Size([1, 16, 10, 10])
pool2 torch.Size([1, 16, 5, 5])
view: torch.Size([1, 400])
fc1: torch.Size([1, 120])
lenet out shape: torch.Size([1, 10])
lenet out: tensor([[0.0950, 0.0998, 0.1101, 0.0975, 0.0966, 0.1097, 0.0948, 0.1056, 0.0992,
0.0917]], device='cuda:0', grad_fn=<SoftmaxBackward>)
1.3运行inference.py文件,将.pth权重转换为.wts形式的权重文件,供后续的tensorrt使用
python inference.py
可以再次得到如下输出,shape 是 [1,10].
cuda device count: 1
input: torch.Size([1, 1, 32, 32])
conv1 torch.Size([1, 6, 28, 28])
pool1: torch.Size([1, 6, 14, 14])
conv2 torch.Size([1, 16, 10, 10])
pool2 torch.Size([1, 16, 5, 5])
view: torch.Size([1, 400])
fc1: torch.Size([1, 120])
lenet out: tensor([[0.0950, 0.0998, 0.1101, 0.0975, 0.0966, 0.1097, 0.0948, 0.1056, 0.0992,
0.0917]], device='cuda:0', grad_fn=<SoftmaxBackward>)
2.在tensorrt中运行lenet5
注意如果有错误的话,请注意lenet5.wts文件位置是否正确
2.1将lenet5.wts权重文件移到 tensorrtExample
目录下并进行 cmake编译
cd lenetTrt/
cp ../lenetPy/lenet5.wts ../
mkdir build
cd build
cmake ..
make
可以看到如下输出
这样我们就得到了可执行的lenet文件
2.2将lenet5.wts权重文件复制进lenetTrt文件夹
注意如果有错误的话,请注意lenet5.wts文件位置是否正确
# 进入 lenetTrt文件夹
cd..
# 复制
cp ../lenet5.wts ./
2.3运行lenet文件来构建tensorrt引擎并且序列化
# 进入build文件夹即lenet可执行文件所在目录
# 具体命令的定义在 lenet.cpp中
./lenet -s
2.3反序列化引擎并进行推理
./lenet -d
得到输出
Output:
0.0949623, 0.0998472, 0.110072, 0.0975036, 0.0965564, 0.109736, 0.0947979, 0.105618, 0.099228, 0.0916792,
3.比较两个结果
pytorch与trt两个输入相同输入shape[1.1.32.32],所以绝对有着同样的输出结果
The pytorch output is
0.0950, 0.0998, 0.1101, 0.0975, 0.0966, 0.1097, 0.0948, 0.1056, 0.0992, 0.0917
The tensorrt output is
0.0949623, 0.0998472, 0.110072, 0.0975036, 0.0965564, 0.109736, 0.0947979, 0.105618, 0.099228, 0.0916792
两个输出shape都是正确的!
4.关于.wts文件格式
说明链接:wts格式解析