Extending TorchScript with Custom C++ Operators的要点
PyTorch 1.0的发布,为 PyTorch 引入了一个新的编程模型,称为 TorchScript。TorchScript 是 Python 的一个子集,它可以由TorchScript 的编译器进行解析、编译和优化。进一步地,编译后的TorchScript 模型可以序列化为磁盘上的文件格式,该文件格式的推理,可以用纯 c + + (以及 Python)加载并运行。
TorchScript 支持torch包提供的大量运算操作,可以轻易将许多复杂模型表示为PyTorch 标准库中的一系列张量操作。不过,可能会需要扩展TorchScript 来自定义的 c + + 或 CUDA 函数。
下面给出了一个用 c + +编写TorchScript 自定义程序,实现OpenCV 的功能
1. 在 c + + 中实现自定义运算符
下面 实现 warpPerspective 函数(对图像进行透视变换),从 OpenCV 转化到 TorchScript 自定义运算符
第一步是在 c + + 中编写自定义运算符:
op.cpp
#include <opencv2/opencv.hpp>
#include <torch/script.h>
torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
// BEGIN image_mat
cv::Mat image_mat(/*rows=*/image.size(0),
/*cols=*/image.size(1),
/*type=*/CV_32FC1,
/*data=*/image.data_ptr<float>());
// END image_mat
// BEGIN warp_mat
cv::Mat warp_mat(/*rows=*/warp.size(0),
/*cols=*/warp.size(1),
/*type=*/CV_32FC1,
/*data=*/warp.data_ptr<float>());
// END warp_mat
// BEGIN output_mat
cv::Mat output_mat;
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{8, 8});
// END output_mat
// BEGIN output_tensor
torch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{8, 8});
return output.clone();
// END output_tensor
}
在文件的顶部,把 OpenCV 的头文件 opencv2/opencv.hpp和 torch/script.h 都include进来,script.h 里面公开了 PyTorch 的 c + + API 中的所有必需的东西
warp 透视函数接受两个参数: 一个是输入image,另一个是应用于图像变换的矩阵warp。输入、返回的类型是 torch::Tensor,
注意:
TorchScript 编译器理解 torch: : : Tensor,torch: : : Scalar,double,int64t 和 std: : vector 这些类型。浮点数只支持 double不支持 float,整型只支持 int64_t,不支持 int、 short 或 long 等。
讲解:
首先将 PyTorch 张量转换为 OpenCV 矩阵,因为 OpenCV 的 warpPerspective 期望 cv: : Mat 作为输入。
为了不需要复制任何数据:
cv::Mat warp_mat(/*rows=*/warp.size(0),
/*cols=*/warp.size(1),
/*type=*/CV_32FC1,
/*data=*/warp.data_ptr<float>());
接下来,在TorchScript 中调用 OpenCV的warpPerspective。为此,传入image mat 和 warp mat ,和一个空的output mat。另外指定了希望输出矩阵(图像)的dsize为8 x 8:
cv::Mat output_mat;
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{8, 8});
自定义运算符实现的最后一步是将output _ mat 转换到PyTorch张量,这样就可以在 PyTorch 中进一步使用它。这与我们之前在另一个l例子上所做的转换一致:PyTorch提供了一个 torch: :from_blob 方法,torch: :from_blob的调用如下:
torch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{8, 8});
return output.clone();
在 OpenCV 的Mat 类上使用 ptr()方法,以获得一个指向底层数据的原始指针。from_blob所输出的 torch::Tensor是一个 torch: : Tensor,指向 OpenCV 矩阵所拥有的内存。
对张量进行clone()以执行基础数据的内存副本。原因是 torch: from_blob 返回了一个没有其数据的张量,该数据仍然属于 OpenCV Mat。cv::Mat output_mat是局部变量,在函数结束时释放,所以调用Clone()返回一个新张量,其中包含新张量所拥有的原始数据的副本。
2. 使用TorchScript 注册自定义运算符
现在已经在 c + + 中实现了自定义运算符,然后需要注册它。注册语法与 pybind11语法非常相似:
TORCH_LIBRARY(my_ops, m) {
m.def("warp_perspective", warp_perspective);
}
在 op.cpp 一开始的某个地方。TORCH_LIBRARY 宏创建一个在程序启动时调用的函数。库的名称(my_ops)作为第一个参数(不需要引号)。第二个参数(m)定义了一个类型为 torch: : Library 的变量,是注册运算符的主接口。def 方法实际上创建了一个名为 warp _ perspective 的操作符,将其暴露给 Python 和 TorchScript。通过对 def 进行多次调用,可以定义任意多个运算符。
3. 编译自定义运算符
有多种方法可以build,这里使用 Setuptools。完全从 Python 编译自定义运算符。优点:setuptools 有一个非常强大和广泛的接口,可以用来构建用 c + + 编写的 Python 模块。但是,setuptools 实际上是用于编译 Python的模块,只需要一个 setup.py 文件:
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
setup(
name="warp_perspective",
ext_modules=[
CppExtension(
"warp_perspective",
["example_app/warp_perspective/op.cpp"],
libraries=["opencv_core", "opencv_imgproc"],
)
],
cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
)
注意,末尾 BuildExtension 中启用了 no_python_abi_suffix ,表示 setuptools 在生成的共享库名称中省略任何特定于 Python-3的 ABI 后缀。例如,在 Python 3.7中,这个库可能被命名为 warp_perspective.cpython-37m-x86_64-linux-gnu。其中cpython-37m-x86_64-linux-gnu 是 ABI 标记,但我们实际上只希望它被命名为 warp_perspective。所以如果我们现在在 setup.py 所在的文件夹中运行
python setup.py build develop
得到
running build
running build_ext
building 'warp_perspective' extension
creating build
creating build/temp.linux-x86_64-3.7
...
Installed /warp_perspective
Processing dependencies for warp-perspective==0.0.0
Finished processing dependencies for warp-perspective==0.0.0
生成一个叫warp_perspective 的共享库。然后将其传递到 torch.ops.load_library,使运算符对于 TorchScript 可见:
torch.ops.load_library("warp_perspective.so")
print(torch.ops.custom.warp_perspective)
# <built-in method custom::warp_perspective of PyCapsule object at 0x7ff51c5b7bd0>