Libtorch部署显著性目标检测网络BASNet
Libtorch 是Pytorch官方提供的C++ API,使用方法可以说是很大程度上还原了Pytorch。
一、环境配置
- GitHub地址:BASNet
- VS2017+CUDA 10.1+cuDNN 7.6.5+OpenCV 3.4.11
- Pytorch 1.5.0+torchvision 0.6.0
- Libtorch 1.5.0 (release):提取码:865k
https://pan.baidu.com/s/1Ty-UPZWEOnNRPwldexLZzw
Libtorch下载的版本一定要和Pytorch版本对应,我这里提供的是release版本的Libtorch
注:
1.原作者GitHub上推荐的是用Pytorch 0.4.0来跑,但是Libtorch是从Pytorch 1.0开始支持的,所以尽量升级为1.x的来运行
2.!!!Pytorch,Libtorch,CUDA,cuDNN版本一定要一一对应
3.训练模型生成.pth的Pytorch版本也最好和转换为torchscript的.pt文件最好也一致,不要用0.4.0的训练用1.7.0的转换。
二、torchscript生成.pt文件
转换之前修改BASNet.py文件(注意这个时候应该已经是训练完了,不要在训练前修改)344行 return F.sigmoid(dout), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6), F.sigmoid(db)
只return第一个值。不更改的话C++那边forward的时候会报错。似乎也能用tuple来接收多个值,这里我就没有研究了。
import torch
from model import BASNet #model文件夹下的BASNet.py文件
model_dir = r"./saved_models/basnet_bsi/basnet_best.pth"
model = BASNet(3,1) # channels, classes
model.load_state_dict(torch.load(model_dir))
if torch.cuda.is_available():
model.cuda()
model.eval()
example = torch.rand(1,3,256,256).cuda() # input example
traced_script_module = torch.jit.trace(model,example)
traced_script_module.save(r"./basnet.pt")
.cuda()很重要,挂载到CUDA上,在C++那边既可以用CPU计算,也可以用GPU计算。反之则只能用CPU计算。
参考:TorchScript使用的注意事项和常见错误
三、Libtorch部署
Libtorch配置有两种方式,一种是CMake,一种是VS手动配置。这里我选择的是VS手动配置。
下载好的Libtorch文件列表如下:
1.添加环境变量
2.VS配置环境
(1)创建一个空项目
(2)添加各种路径
i.项目—》属性—》C/C++—》常规—》附加包含目录
D:\Libraries\libtorch\include
D:\Libraries\libtorch\include\torch\csrc