使用 C++, CUDA 扩展 PyTorch
准备工作
请先确定一下 PyTorch 版本:
>>> torch.__version__
‘1.9.1’
>>> torch.version.cuda
‘11.1’
>>> torch.backends.cudnn.version()
8005
实验环境:
用 C++ 扩展 PyTorch
官网上案例使用的算子大多和 python 的接口差不多,而有的时候我们需要操作 Tensor 的具体元素,这时候就要用到 Using accessors 那一小节介绍的技术
这里我们在人尽皆知的快速排序上进行实验:
首先是我们的源程序:
quickSort.cpp
#include <torch/extension.h>
void quickSort(torch::Tensor& src, int begin, int end) {
auto src_a = src.accessor<int, 1>();
if (begin < end) {
auto key = src_a[begin];
int i = begin;
int j = end;
while (i < j) {
while (i < j && src_a[j] > key)
--j;
if (i < j) {
src_a[i] = src_a[j];
++i;
}
while (i < j && src_a[i] < key)
++i;
if (i < j) {
src_a[j] = src_a[i];
--j;
}
}
src_a[i] = key;
quickSort(src, begin, i-1);
quickSort(src, i+1, end);
}
}
void QuickSort(torch::Tensor& src) {
quickSort(src, 0, src.size(0));
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m