pytorch 提取权重_利用Pytorch的C++前端(libtorch)读取预训练权重并进行预测

本文介绍了如何使用PyTorch的C++前端库Libtorch进行预训练模型的权重提取和预测。首先,讨论了PyTorch 1.0版本的C++生态支持,然后详细阐述了如何编译和使用Libtorch,包括官方库的下载、源码编译的步骤以及可能遇到的问题。最后,展示了如何结合OpenCV读取图像并进行实时预测,实现了从摄像头获取帧数据,转换为Tensor,输入模型进行手势识别的过程。
摘要由CSDN通过智能技术生成

cecbb829925ae5dc3514e9009d29d764.png

前言

距离发布Pytorch-1.0-Preview版的发布已经有两个多月,Pytorch-1.0最瞩目的功能就是生产的大力支持,推出了C++版本的生态端(FB之前已经在Detectron进行了实验),包括C++前端和C++模型编译工具。

对于我们来说,之后如果想要部署深度学习应用的时候,只需要在Python端利用Pytorch进行训练,然后使用torch.jit导出我们训练好的模型,再利用C++端的Pytorch读取进行预测即可,当然C++端的Pytorch也是可以进行训练的。

因为我们使用的C++版的Pytorch实际上为编译好的动态链接库和头文件,官方提供已经编译好的下载包:

698f7d88bc2f0d37790b7157ca0f1fb3.png

通过这个小教程我们可以了解到这个库的基本用法。

下图是利用Libtorch + OpenCV-4.0.0在GPU端进行的预测(简单识别手势),所使用的语言为C++,相较python版本的预测速度提升10%。

95fc9236a1036527766faf19d55860dc.png

好了,废话不多少,接下来聊聊如何使用它吧~

正式开始

Pytorch-1.0已经发布两个月了,为什么今天才进行尝试呢——原因很简单,个人比较担心其接口的不稳定性,故稍微多等乐些时间再进行尝试。虽然多等了,但是资料依然很是匮乏,官方的相关教程少之可怜,唯一参考信息的获取只有少数的博客和github上的issue了。

但是有一点好消息,相比于之前,现在尝试libtorch已经几乎没什么问题了,各方面都已经完善,如果大家对libtorch感兴趣,那么这篇文章就比较适合你啦~

另外还有个消息,Pytorch-1.0的稳定版将在这个星期五发布,也就是明天:

7798acdf3b28d44a6c4eef46293e9b25.png

这样下来,libtorch的接口已经基本稳定,剩下的就让我们感觉尝尝鲜吧。

获取libtorch

获取libtorch的方式有两种:

我这里推荐第二种,因为官方编译好的版本为了兼容性,选择了旧式的C++-ABI(相关链接:https://github.com/pytorch/pytorch/issues/13541 ; https://discuss.pytorch.org/t/issues-linking-with-libtorch-c-11-abi/29510),如果你使用的gcc版本>5,那么如果你将libtorch与其他编译好的库(使用gcc-5以及以上)进行联合编译,很有可能出现冲突,为了避免环境上面的问题,建议自己对源码进行编译。当然大家也可以测试下官方的

当然还有一点需要说明,如果你仅仅只单独使用libtorch库(从官方下载,并没有链接其他库,例如opencv),那么你这样编译那么是没有任何问题的。大家可以直接下载官方编译好的包进行快速尝试。

源码编译

安装好所有的依赖件后,我们下载好官方的源码,然后进入Pytorch源码目录环境执行:

git submodule update --init --recursive # 执行更新第三方库,确保安装成功

mkdir build

cd build

python ../tools/build_libtorch.py

有个ISSUE提到必须将源码目录中tools/build_pytorch_libs.sh第127行左右添加一句(-D_GLIBCXX_USE_CXX11_ABI=1)再进行编译:

THIRD_PARTY_DIR="$BASE_DIR/third_party"

C_FLAGS="" # 添加上 -D_GLIBCXX_USE_CXX11_ABI=1.

# Workaround OpenMPI build failure

# ImportError: /build/pytorch-0.2.0/.pybuild/pythonX.Y_3.6/build/torch/_C.cpython-36m-x86_64-linux-gnu.so: undefined symbol: _ZN3MPI8Datatype4FreeEv

# https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=686926

C_FLAGS="${C_FLAGS} -DOMPI_SKIP_MPICXX=1"

LDFLAGS=""

这个其实并不需要,我们直接编译即可。

这一部其实类似于Pytorch的源码编译,至于其中的细节(cuda、cudnn版本)这里不进行赘述了,大家可以查阅本站相关内页或者根据网上教程来进行安装:

如果编译无错之后我们会看到输出信息:

-- Install configuration: "Release"

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值