用自定义c++运算符扩展TorchScript

Extending TorchScript with Custom C++ Operators的要点

PyTorch 1.0的发布,为 PyTorch 引入了一个新的编程模型,称为 TorchScript。TorchScript 是 Python 的一个子集,它可以由TorchScript 的编译器进行解析、编译和优化。进一步地,编译后的TorchScript 模型可以序列化为磁盘上的文件格式,该文件格式的推理,可以用纯 c + + (以及 Python)加载并运行。

TorchScript 支持torch包提供的大量运算操作,可以轻易将许多复杂模型表示为PyTorch 标准库中的一系列张量操作。不过,可能会需要扩展TorchScript 来自定义的 c + + 或 CUDA 函数。

下面给出了一个用 c + +编写TorchScript 自定义程序,实现OpenCV 的功能

1. 在 c + + 中实现自定义运算符

下面 实现 warpPerspective 函数(对图像进行透视变换),从 OpenCV 转化到 TorchScript 自定义运算符

第一步是在 c + + 中编写自定义运算符:

op.cpp

#include <opencv2/opencv.hpp>
#include <torch/script.h>

torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
  // BEGIN image_mat
  cv::Mat image_mat(/*rows=*/image.size(0),
                    /*cols=*/image.size(1),
                    /*type=*/CV_32FC1,
                    /*data=*/image.data_ptr<float>());
  // END image_mat

  // BEGIN warp_mat
  cv::Mat warp_mat(/*rows=*/warp.size(0),
                   /*cols=*/warp.size(1),
                   /*type=*/CV_32FC1,
                   /*data=*/warp.data_ptr<float>());
  // END warp_mat

  // BEGIN output_mat
  cv::Mat output_mat;
  cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{8, 8});
  // END output_mat

  // BEGIN output_tensor
  torch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{8, 8});
  return output.clone();
  // END output_tensor
}

在文件的顶部,把 OpenCV 的头文件 opencv2/opencv.hpp和 torch/script.h 都include进来,script.h 里面公开了 PyTorch 的 c + + API 中的所有必需的东西

warp 透视函数接受两个参数: 一个是输入image,另一个是应用于图像变换的矩阵warp。输入、返回的类型是 torch::Tensor,

注意:

TorchScript 编译器理解 torch: : : Tensor,torch: : : Scalar,double,int64t 和 std: : vector 这些类型。浮点数只支持 double不支持 float,整型只支持 int64_t,不支持 int、 short 或 long 等。

讲解:

首先将 PyTorch 张量转换为 OpenCV 矩阵,因为 OpenCV 的 warpPerspective 期望 cv: : Mat 作为输入。

为了不需要复制任何数据:

cv::Mat warp_mat(/*rows=*/warp.size(0),
                 /*cols=*/warp.size(1),
                 /*type=*/CV_32FC1,
                 /*data=*/warp.data_ptr<float>());

接下来,在TorchScript 中调用 OpenCV的warpPerspective。为此,传入image mat 和 warp mat ,和一个空的output mat。另外指定了希望输出矩阵(图像)的dsize为8 x 8:

cv::Mat output_mat;
cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{8, 8});

自定义运算符实现的最后一步是将output _ mat 转换到PyTorch张量,这样就可以在 PyTorch 中进一步使用它。这与我们之前在另一个l例子上所做的转换一致:PyTorch提供了一个 torch: :from_blob 方法,torch: :from_blob的调用如下:

torch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{8, 8});
return output.clone();

在 OpenCV 的Mat 类上使用 ptr()方法,以获得一个指向底层数据的原始指针。from_blob所输出的 torch::Tensor是一个 torch: : Tensor,指向 OpenCV 矩阵所拥有的内存。

对张量进行clone()以执行基础数据的内存副本。原因是 torch: from_blob 返回了一个没有其数据的张量,该数据仍然属于 OpenCV Mat。cv::Mat output_mat是局部变量,在函数结束时释放,所以调用Clone()返回一个新张量,其中包含新张量所拥有的原始数据的副本。

2. 使用TorchScript 注册自定义运算符

现在已经在 c + + 中实现了自定义运算符,然后需要注册它。注册语法与 pybind11语法非常相似:

TORCH_LIBRARY(my_ops, m) {
  m.def("warp_perspective", warp_perspective);
}

在 op.cpp 一开始的某个地方。TORCH_LIBRARY 宏创建一个在程序启动时调用的函数。库的名称(my_ops)作为第一个参数(不需要引号)。第二个参数(m)定义了一个类型为 torch: : Library 的变量,是注册运算符的主接口。def 方法实际上创建了一个名为 warp _ perspective 的操作符,将其暴露给 Python 和 TorchScript。通过对 def 进行多次调用,可以定义任意多个运算符。

3. 编译自定义运算符

有多种方法可以build,这里使用 Setuptools。完全从 Python 编译自定义运算符。优点:setuptools 有一个非常强大和广泛的接口,可以用来构建用 c + + 编写的 Python 模块。但是,setuptools 实际上是用于编译 Python的模块,只需要一个 setup.py 文件:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="warp_perspective",
    ext_modules=[
        CppExtension(
            "warp_perspective",
            ["example_app/warp_perspective/op.cpp"],
            libraries=["opencv_core", "opencv_imgproc"],
        )
    ],
    cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
)

注意,末尾 BuildExtension 中启用了 no_python_abi_suffix ,表示 setuptools 在生成的共享库名称中省略任何特定于 Python-3的 ABI 后缀。例如,在 Python 3.7中,这个库可能被命名为 warp_perspective.cpython-37m-x86_64-linux-gnu。其中cpython-37m-x86_64-linux-gnu 是 ABI 标记,但我们实际上只希望它被命名为 warp_perspective。所以如果我们现在在 setup.py 所在的文件夹中运行

python setup.py build develop

得到

running build
running build_ext
building 'warp_perspective' extension
creating build
creating build/temp.linux-x86_64-3.7

...

Installed /warp_perspective
Processing dependencies for warp-perspective==0.0.0
Finished processing dependencies for warp-perspective==0.0.0

生成一个叫warp_perspective 的共享库。然后将其传递到 torch.ops.load_library,使运算符对于 TorchScript 可见:

torch.ops.load_library("warp_perspective.so")
print(torch.ops.custom.warp_perspective)
# <built-in method custom::warp_perspective of PyCapsule object at 0x7ff51c5b7bd0>
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值