小白科研笔记:理解PointNet++中的three_interpolate前向计算和反向求导

1. 前言

处于科研需要,我需要理解PointNet++中的Interpolated (propogated) features,即特征插值计算过程(如果不需要当然就不去理解它啦)。它的公式所示:

在这里插入图片描述
它在代码中的实现如下所示:

# known 表示已知点的位置信息 [m,4]
# known_feats 表示已知点的特征信息 [m,C]
# unknown 表示需要插值点的位置信息 [n,4],一般来所,n>m
# interpolated_feats 表示需要插值点的特征信息 [n,C],这是返回结果
def nearest_neighbor_interpolate(unknown, known, known_feats):
    """
    :param pts: (n, 4) tensor of the bxyz positions of the unknown features
    :param ctr: (m, 4) tensor of the bxyz positions of the known features
    :param ctr_feats: (m, C) tensor of features to be propigated
    :return:
        new_features: (n, C) tensor of the features of the unknown features
    """
    # 获取 unknown 和 known 之间的近邻关系和距离信息
    dist, idx = pointnet2_utils.three_nn(unknown, known)
    # 权值是距离的倒数
    dist_recip = 1.0 / (dist + 1e-8)
    norm = torch.sum(dist_recip, dim=1, keepdim=True)
    weight = dist_recip / norm
    # 根据近邻关系以及距离信息,直接插值特征信息
    interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)

    return interpolated_feats

这篇博客主要分析以下两点:

  • pointnet2_utils.three_nn的代码细节
  • pointnet2_utils.three_interpolate的代码细节

2. three_nn代码细节

three_nn定义代码如下所示。但是three_nn只是用于找到目标点最近的三个点。也就意味着 M = 3 M=3 M=3。事实上SA-SSD中的结果实验也没有做 M = 5 , 7 , . . . M=5,7,... M=5,7,...情形下的实验。无所谓啦。注意一个小细节,就是backward计算为None,因为近邻点的寻找本身不会参与反向传播。而插值计算会有反向传播导数,这会在第四节讨论。在@staticmethod中,符号->表示返回值类型。

class ThreeNN(Function):

    @staticmethod
    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find the three nearest neighbors of unknown in known
        :param ctx:
        :param unknown: (N, 3)
        :param known: (M, 3)
        :return:
            dist: (N, 3) l2 distance to the three nearest neighbors
            idx: (N, 3) index of 3 nearest neighbors
        """
        # 保证输入变量内存是连续的
        assert unknown.is_contiguous()
        assert known.is_contiguous()

		# N 和 m 表示待插值点的数量和已知点的数量
        N, _ = unknown.size()
        m = known.size(0)
        dist2 = torch.cuda.FloatTensor(N, 3)
        idx = torch.cuda.IntTensor(N, 3)

        pointnet2.three_nn_wrapper(N, m, unknown, known, dist2, idx)
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        return None, None


three_nn = ThreeNN.apply

进一步看看three_nn_wrapper的内部结构。代码中使用pybinderpython调用cuda代码,前提是cuda代码已被setup.py文件用pip install -e .命令编译一遍,以生成pybinder可直接调用的so文件。具体细节会在第四节讨论。

// 使用 pybinder 绑定 python 中调用函数以及 cuda 实现
#include <torch/serialize/tensor.h>
#include <torch/extension.h>

#include "interpolate_gpu.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
    m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
    m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");
}

可见,python函数three_nn_wrapper对应cpp函数three_nn_wrapper_fasttorch类型变量对应at::Tensor类型变量,然后把它变成cuda可以应对的float或者int类型的数组。

extern THCState *state;
void three_nn_wrapper_fast(int n, int m, at::Tensor unknown_tensor,
    at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
    // 读取数据
    const float *unknown = unknown_tensor.data<float>();
    const float *known = known_tensor.data<float>();
    float *dist2 = dist2_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = THCState_getCurrentStream(state);
    three_nn_kernel_launcher_fast(n, m, unknown, known, dist2, idx, stream);
}

在该cpp文件中,调用cuda内核函数three_nn_kernel_launcher_fast

__global__ void three_nn_kernel_fast(int n, int m, const float *__restrict__ unknown,
    const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
    // unknown: (N, 4)
    // known: (M, 4)
    // output:
    //      dist2: (N, 3)
    //      idx: (N, 3)


    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (pt_idx >= n) return;

    unknown += pt_idx * 4;

    dist2 += pt_idx * 3;
    idx += pt_idx * 3;

    float ub = unknown[0];
    float ux = unknown[1];
    float uy = unknown[2];
    float uz = unknown[3];

    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
    int besti1 = 0, besti2 = 0, besti3 = 0;
    // 好吧,这段最小三个数查找的代码,居然是这样实现的(笑哭)
    for (int k = 0; k < m; ++k) {
        float b = known[k * 4 + 0]; //batch number
        if (b!=ub)
            continue;
        float x = known[k * 4 + 1];
        float y = known[k * 4 + 2];
        float z = known[k * 4 + 3];
        float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
        // 我也是醉了
        if (d < best1) {
            best3 = best2; besti3 = besti2;
            best2 = best1; besti2 = besti1;
            best1 = d; besti1 = k;
        }
        else if (d < best2) {
            best3 = best2; besti3 = besti2;
            best2 = d; besti2 = k;
        }
        else if (d < best3) {
            best3 = d; besti3 = k;
        }
    }
    dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
    idx[0] = besti1; idx[1] = besti2; idx[2] = besti3;
}

void three_nn_kernel_launcher_fast(int n, int m, const float *unknown,
    const float *known, float *dist2, int *idx, cudaStream_t stream) {
    // unknown: (N, 4)
    // known: (M, 4)
    // output: 
    //      dist2: (N, 3)
    //      idx: (N, 3)

    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    three_nn_kernel_fast<<<blocks, threads, 0, stream>>>(n, m, unknown, known, dist2, idx);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

3. Pytorch调用自定义Cuda函数细节

咱们梳理一下pytorch调用cuda函数的细节。我画了一张调用图,如下所示。

在这里插入图片描述
图1:调用关系图

调用setup.py文件对cucpp文件进行编译:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='pointnet2',
    ext_modules=[
        CUDAExtension('pointnet2_cuda', [
            'src/pointnet2_api.cpp',
            'src/interpolate.cpp', 
            'src/interpolate_gpu.cu',
        ],
        extra_compile_args={'cxx': ['-g'],
                            'nvcc': ['-O2']})
    ],
    cmdclass={'build_ext': BuildExtension}
)

4. three_interpolate代码细节

three_interpolate的具体调用过程跟three_nn一样,跟图1的调用关系是一致的,就不做过多叙述。直接看代码。有个细节注意,forward函数和backward函数的->正好是相反的。backward函数的返回是grad_features, None, None,因为只对输入已知点的特征计算导数。

class ThreeInterpolate(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """
        Performs weight linear interpolation on 3 features
        :param ctx:
        :param features: (M, C) Features descriptors to be interpolated from
        :param idx: (n, 3) three nearest neighbors of the target features in features
        :param weight: (n, 3) weights
        :return:
            output: (N, C) tensor of the interpolated features
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        assert weight.is_contiguous()

        m, c = features.size()
        n = idx.size(0)
        ctx.three_interpolate_for_backward = (idx, weight, m)
        output = torch.cuda.FloatTensor(n, c)

        pointnet2.three_interpolate_wrapper(c, m, n, features, idx, weight, output)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param ctx:
        :param grad_out: (N, C) tensor with gradients of outputs
        :return:
            grad_features: (M, C) tensor with gradients of features
            None:
            None:
        """
        idx, weight, m = ctx.three_interpolate_for_backward
        n, c = grad_out.size()

		# 对 输入特征导数 做初始化
        grad_features = Variable(torch.cuda.FloatTensor(m, c).zero_())
        grad_out_data = grad_out.data.contiguous()

        pointnet2.three_interpolate_grad_wrapper( c, n, m, grad_out_data, idx, weight, grad_features.data)
        return grad_features, None, None

先看前向计算部分。函数three_interpolate_wrapper最终调用函数three_interpolate_kernel_fast,代码如下所示:

__global__ void three_interpolate_kernel_fast(int c, int m, int n, const float *__restrict__ points,
    const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) {
    // points: (M, C)
    // idx: (N, 3)
    // weight: (N, 3)
    // output:
    //      out: (N, C)


    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (c_idx >= c || pt_idx >= n) return;

    weight += pt_idx * 3;
    //points += c_idx * m;

    idx += pt_idx * 3;

    out += pt_idx * c;

	// 计算比较直接
    out[c_idx] = weight[0] * points[idx[0] * c + c_idx] + weight[1] * points[idx[1] * c + c_idx] + weight[2] * points[idx[2] * c + c_idx];
}

void three_interpolate_kernel_launcher_fast(int c, int m, int n,
    const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) {
   // points: (M, C)
    // idx: (N, 3)
    // weight: (N, 3)
    // output:
    //      out: (N, C)

    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);
    three_interpolate_kernel_fast<<<blocks, threads, 0, stream>>>(c, m, n, points, idx, weight, out);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

然后再看反向传播部分。函数three_interpolate_grad_wrapper将调用函数``,代码如下所示:

__global__ void three_interpolate_grad_kernel_fast(int c, int n, int m, const float *__restrict__ grad_out,
    const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) {
    // grad_out: (N, C)
    // weight: (N, 3)
    // idx: (N, 3)
    // output:
    //      grad_points: (M, C)


    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (c_idx >= c || pt_idx >= n) return;
    
    grad_out += pt_idx * c + c_idx;
    weight += pt_idx * 3;
    //grad_points += c_idx * m;
    idx += pt_idx * 3;

    atomicAdd(grad_points + idx[0] * c + c_idx, grad_out[0] * weight[0]);
    atomicAdd(grad_points + idx[1] * c + c_idx, grad_out[0] * weight[1]);
    atomicAdd(grad_points + idx[2] * c + c_idx, grad_out[0] * weight[2]);
}

void three_interpolate_grad_kernel_launcher_fast(int c, int n, int m, const float *grad_out,
    const int *idx, const float *weight, float *grad_points, cudaStream_t stream) {
    // grad_out: (N, C)
    // weight: (N, 3)
    // idx: (N, 3)
    // output:
    //      grad_points: (M, C)

    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);
    three_interpolate_grad_kernel_fast<<<blocks, threads, 0, stream>>>(c, n, m, grad_out, idx, weight, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

反向导数计算过程我不是特别懂。我猜grad_out应该是下一层网络的导数,即误差 e e e对插值后特征 f i ˉ \bar{f_i} fiˉ的导数,即 ∂ e / ∂ f i ˉ {\partial {e}}/{\partial \bar{f_i}} e/fiˉ。如果真的是这样,就好解释了。反向导数计算肯定是需要计算 ∂ f i ˉ / ∂ f j {\partial \bar{f_i}}/{\partial f_j} fiˉ/fj,这个计算过程在代码中做了简化:

∂ f i ˉ / ∂ f j = w j ( p i ) / ∑ i = 1 M w i ( p i ) ∝ w j ( p i ) {\partial \bar{f_i}}/{\partial f_j} = {w_j(p_i)}/{\sum_{i=1}^M w_i(p_i)} \propto w_j(p_i) fiˉ/fj=wj(pi)/i=1Mwi(pi)wj(pi)

这一层的导数(即误差 e e e对插值前特征 f j {f_j} fj的导数)计算就是:

∂ e / ∂ f j = ∂ e / ∂ f i ˉ ∗ ∂ f i ˉ / ∂ f j ∝ g r a d _ o u t ∗ w e i g h t {\partial {e}}/{\partial {f_j}} = {\partial {e}}/{\partial \bar{f_i}} * {\partial \bar{f_i}}/{\partial f_j} \propto grad\_out*weight e/fj=e/fiˉfiˉ/fjgrad_outweight

5. 结束语

挺有趣的哈哈。

  • 12
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值