【代码阅读】详解在Pytorch中定义自己写的CUDA编程函数

目前,3D的网络,尤其时point-based的网络,很多模块在pytorch中都没有官方实现,这就需要我们自己写。例如PointNet++中的FPS,group,query等函数。之前也只是用过,对其的修改也限于python层面,这次,就好好探究一下,如何自定义一个函数,如何将其加入到pytorch中,使得在pytorch中也能用。

其实,这一块,有非常详细的官方文档,讲述了如何自定义一个函数,并将其放入Pytorch中。当然,如何写一个函数,我们还需要cuda编程的知识,这里就先讲外围一部分,假设我们已经写好了一个函数。官网文档的示例讲的很清楚了,这里就拿PointNet++来说明一下。如果想要详细了解的话,可以先看一下官方文档:
https://pytorch.org/tutorials/advanced/cpp_extension.html?highlight=pybind11_module

官方文档中清楚的给出了两种将自己定义的cuda编程的函数放入pytorch中的方法。一种是通过编译,生成一个python的包,一种是在程序执行中调用。

个人认为编译的方法更好一些,生成了一个python包,在其他的project中也很方便调用。

首先,我们先看一下pytorch接口的设置,这里,我们先假设已经写好了函数。

pytorch接口设置

编译的方式

这里的PointNet++版本以这个链接中的为例:
https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/5a4416f51ceaeba242828cabf39133433336850d

假设我们已经写好了要实现的函数,在本例中,函数包括pointnet2/src中的一系列xxx.cpp,xxx.cu和xxx.h。

那么我们如何放到pytorch的接口中呢?这就要看pointnet2/setup.py中:

# 这两个import是标准写法,不用改,setuptools是为了把我们自定义的函数变成一个包
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
	# 这个包的name是pointnet2
    name='pointnet2', 
    ext_modules=[
    	# 模块的name是pointnet2_cuda,就是说要import pointnet2_cuda
    	# 定义与这个包关联的xxx.cpp, xxx.cu, xxx.h
        CUDAExtension('pointnet2_cuda', [
            'src/pointnet2_api.cpp',
            
            'src/ball_query.cpp', 
            'src/ball_query_gpu.cu',
            'src/group_points.cpp', 
            'src/group_points_gpu.cu',
            'src/interpolate.cpp', 
            'src/interpolate_gpu.cu',
            'src/sampling.cpp', 
            'src/sampling_gpu.cu',
        ],
        # 以下的东西都不用改
        extra_compile_args={
   'cxx': ['-g'],
                            'nvcc': ['-O2']})
    ],
    cmdclass={
   'build_ext': BuildExtension}
)

在我们用这些函数之前,要先运行

python setup.py install

其实就是在把我们定义的这些函数,集合成一个包安装起来。这就出现了一个问题,函数包是安装上了,但我们用什么接口去调用函数呢?

这部分就定义在pointnet2/pointnet2_api.py中

#include <torch/serialize/tensor.h>
#include <torch/extension.h>

// 把写好的函数都先include进来
#include "ball_query_gpu.h"
#include "group_points_gpu.h"
#include "sampling_gpu.h"
#include "interpolate_gpu.h"

// 使用PYBIND11_MODULE,这个是在torch/extension.h中包含了的
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   
	// python中调用时使用的函数名为:ball_query_wrapper
	// cpp中相关的函数是:ball_query_wrapper_fast
	// python中调用help所产生的提示是:"ball_query_wrapper_fast"
    m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast");

    m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast");
    m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast");

    m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
    m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");

    m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
    
    m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
    m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
    m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");
}

上面就完成了pytorch中要调用的接口了。那么我们看一下,是如何调用的?

这个在pointnet2/pointnet2_utils.py中,以

import torch
from torch.autograd import Variable
from torch.autograd import Function
import torch.nn as nn
from typing import Tuple

# import我们自己定义的包
import pointnet2_cuda as pointnet2

# 定义一个pytorch的函数,要继承torch.autograd.Function
class GatherOperation(Function):
	
	# 定义前向运算,ctx保存一些变量,保存如ctx中的变量会传入backward中
    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param features: (B, C, N)
        :param idx: (B, npoint) index tensor of the features to gather
        :return:
            output: (B, C, npoint)
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()

        B, npoint = idx.size()
        _, C, N = features.size()
        output = torch.cuda.FloatTensor(B, C, npoint)
		
		# 调用我们定义的函数,进行计算
        pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
		
		# 将反向传播中要用到的变量放入ctx中
        ctx.for_backwards = (idx, C, N)
        return output
	
	# 定义反向传播的函数,其输入的第一个变量是ctx,然后其他输入的数量与forward的输出的数量相同
    @staticmethod
    def backward(ctx, grad_out):
		# 从ctx中取出前向计算中保存的变量
        idx, C, N = ctx.for_backwards
        B, npoint = idx.size()

        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
        grad_out_data = grad_out.data.contiguous()
        pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)

		# 输出变量的数量必须与forward输入的变量数量(除ctx之外)相同
        return grad_features, None

# 调用我们定义的函数的方法是outputs = xxx.apply(inputs),这里预先把apply取出来,所以用的时候就可以直接使用 outputs = gather_operation(inputs)即可
gather_operation = GatherOperation.apply

在运行是调用的形式

以PVCNN中的代码为例。PVCNN中的xxx.cpp,xxx.cu,xxx.h都modules/functional/src文件夹中。

对应编译的方式的顺序来看,先看看,xxx.cpp和xxx.cu是怎么被pytorch所知道的呢?这个在modules/backend.py中:

import os

from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))
_backend = load(name='_pvcnn_backend',
                extra_cflags=['-O3', '-std=c++17'],
                sources=[os
  • 16
    点赞
  • 56
    收藏
    觉得还不错? 一键收藏
  • 12
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值