目前,3D的网络,尤其时point-based的网络,很多模块在pytorch中都没有官方实现,这就需要我们自己写。例如PointNet++中的FPS,group,query等函数。之前也只是用过,对其的修改也限于python层面,这次,就好好探究一下,如何自定义一个函数,如何将其加入到pytorch中,使得在pytorch中也能用。
其实,这一块,有非常详细的官方文档,讲述了如何自定义一个函数,并将其放入Pytorch中。当然,如何写一个函数,我们还需要cuda编程的知识,这里就先讲外围一部分,假设我们已经写好了一个函数。官网文档的示例讲的很清楚了,这里就拿PointNet++来说明一下。如果想要详细了解的话,可以先看一下官方文档:
https://pytorch.org/tutorials/advanced/cpp_extension.html?highlight=pybind11_module
官方文档中清楚的给出了两种将自己定义的cuda编程的函数放入pytorch中的方法。一种是通过编译,生成一个python的包,一种是在程序执行中调用。
个人认为编译的方法更好一些,生成了一个python包,在其他的project中也很方便调用。
首先,我们先看一下pytorch接口的设置,这里,我们先假设已经写好了函数。
pytorch接口设置
编译的方式
这里的PointNet++版本以这个链接中的为例:
https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/5a4416f51ceaeba242828cabf39133433336850d
假设我们已经写好了要实现的函数,在本例中,函数包括pointnet2/src中的一系列xxx.cpp,xxx.cu和xxx.h。
那么我们如何放到pytorch的接口中呢?这就要看pointnet2/setup.py中:
# 这两个import是标准写法,不用改,setuptools是为了把我们自定义的函数变成一个包
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
# 这个包的name是pointnet2
name='pointnet2',
ext_modules=[
# 模块的name是pointnet2_cuda,就是说要import pointnet2_cuda
# 定义与这个包关联的xxx.cpp, xxx.cu, xxx.h
CUDAExtension('pointnet2_cuda', [
'src/pointnet2_api.cpp',
'src/ball_query.cpp',
'src/ball_query_gpu.cu',
'src/group_points.cpp',
'src/group_points_gpu.cu',
'src/interpolate.cpp',
'src/interpolate_gpu.cu',
'src/sampling.cpp',
'src/sampling_gpu.cu',
],
# 以下的东西都不用改
extra_compile_args={
'cxx': ['-g'],
'nvcc': ['-O2']})
],
cmdclass={
'build_ext': BuildExtension}
)
在我们用这些函数之前,要先运行
python setup.py install
其实就是在把我们定义的这些函数,集合成一个包安装起来。这就出现了一个问题,函数包是安装上了,但我们用什么接口去调用函数呢?
这部分就定义在pointnet2/pointnet2_api.py中
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
// 把写好的函数都先include进来
#include "ball_query_gpu.h"
#include "group_points_gpu.h"
#include "sampling_gpu.h"
#include "interpolate_gpu.h"
// 使用PYBIND11_MODULE,这个是在torch/extension.h中包含了的
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// python中调用时使用的函数名为:ball_query_wrapper
// cpp中相关的函数是:ball_query_wrapper_fast
// python中调用help所产生的提示是:"ball_query_wrapper_fast"
m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast");
m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast");
m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast");
m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");
}
上面就完成了pytorch中要调用的接口了。那么我们看一下,是如何调用的?
这个在pointnet2/pointnet2_utils.py中,以
import torch
from torch.autograd import Variable
from torch.autograd import Function
import torch.nn as nn
from typing import Tuple
# import我们自己定义的包
import pointnet2_cuda as pointnet2
# 定义一个pytorch的函数,要继承torch.autograd.Function
class GatherOperation(Function):
# 定义前向运算,ctx保存一些变量,保存如ctx中的变量会传入backward中
@staticmethod
def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param features: (B, C, N)
:param idx: (B, npoint) index tensor of the features to gather
:return:
output: (B, C, npoint)
"""
assert features.is_contiguous()
assert idx.is_contiguous()
B, npoint = idx.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, npoint)
# 调用我们定义的函数,进行计算
pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
# 将反向传播中要用到的变量放入ctx中
ctx.for_backwards = (idx, C, N)
return output
# 定义反向传播的函数,其输入的第一个变量是ctx,然后其他输入的数量与forward的输出的数量相同
@staticmethod
def backward(ctx, grad_out):
# 从ctx中取出前向计算中保存的变量
idx, C, N = ctx.for_backwards
B, npoint = idx.size()
grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
# 输出变量的数量必须与forward输入的变量数量(除ctx之外)相同
return grad_features, None
# 调用我们定义的函数的方法是outputs = xxx.apply(inputs),这里预先把apply取出来,所以用的时候就可以直接使用 outputs = gather_operation(inputs)即可
gather_operation = GatherOperation.apply
在运行是调用的形式
以PVCNN中的代码为例。PVCNN中的xxx.cpp,xxx.cu,xxx.h都modules/functional/src文件夹中。
对应编译的方式的顺序来看,先看看,xxx.cpp和xxx.cu是怎么被pytorch所知道的呢?这个在modules/backend.py中:
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
_backend = load(name='_pvcnn_backend',
extra_cflags=['-O3', '-std=c++17'],
sources=[os