DSVT 论文阅读

https://openaccess.thecvf.com/content/CVPR2023/papers/Wang_DSVT_Dynamic_Sparse_Voxel_Transformer_With_Rotated_Sets_CVPR_2023_paper.pdf

摘要:

 高效且方便部署的稀疏点云3D backbone的设计很关键,和稀疏卷积相比,attention机制更适合长距离的关系,而且更容易在硬件上实现部署。

这个paper带来的是,Dynamic Sparse Voxel Transformer,单步长,基于窗口的网格的Transformer的backbone。

为了方便的并行的处理稀疏的点云,这里根据稀疏性,划分成很多局部的区域,然后并行的计算局部区域的特征;考虑到不同区域之间的交互,这里设计了一个旋转的区域划分策略;为了有效支持降采样和编码几何信息,这里也设计了一个基于attention的3D的pooling模块。

1. 介绍

点云的特点,稀疏而且在连续空间中不规则的分布。分析对比了当前的一些方法:

point_based的稀疏处理:ball-query 和max pooling提取特征;计算消耗太大。

voxel_based的稀疏处理:转化为规则的grid,grid上进行sparse conv;因为submanifold dilation问题,用了sparse conv逐步的sparse就变得相对dense了。

更关键的是,无法简单的部署实现,需要CUDA定制OP。

已经有很多在3D的点云中实现的transformer方案,但是由于点云的稀疏性,不好直接使用标准的transformer block。有些方法,通过随机采样的方法,或者group的方法(SST,swformer),性能不好,或者需要额外的padding操作或者无法并行的attention等浪费资源的操作。或者有一些定制cuda进行实现(Voxel set transformer, Voxel transformer for 3d object detection)。针对如上的一些问题,本文提出了高效,容易部署的方案。

两个关键模块:

dynamic sparse window attention,parallel computation of local windows with diverse sparsity

首先是voxelization,得到稀疏的voxel;

 然后在连续的self attention layer上,rotating的划分稀疏的voxel为不同的set,这些set之内进行并行的self attention;set之间的交互,通过rotating不同的划分;

        This strategy is efficient in regards to real-world latency:

        i) all the windows are processed in parallel, which is nearly independent of the voxel distribution,

        ii) using self-attention without key duplication, which facilitates memory access in hardware.

learnable 3D pooling , downsample和encode geometric information

we first convert the sparse downsampling region into dense and process an attention-style 3D pooling operation to automatically aggregate the local spatial features

2. 相关工作

2.1 点云的3D感知

基于点

基于voxel

2.2 基于Transformer的3D感知

VoTr

SST

SWFormer

和上面的介绍有点重复。

3.  方法

3.1 overview

 核心是DSVT Block。

3.2 dynamic sparse window attention

如果直接借鉴vit/swin的方式,每个window内的非空voxel数量差别很大,不好高效的并行计算。这里其实采取的是划分完window后再细分subset的方式,把颗粒度进一步变小,为了保证set长度一致,同样采用了padding的方式。

dynamic set partition

首先是划分window; window的尺寸选为LxWxH

然后是针对某个window,假设有N和非空voxel,N个非空voxel的组成如下;

然后对这N个voxel进一步划分subset, S即为需要的subset的数量

分配N个非空voxel到这些划分好的subset内,可能会有些重复,但是会被mask掉

 

 然后根据划分好的subset,其中的voxel的index,提取这个subset下的voxel的feature

 Rotated set attention

每个set之内是单独的self attention,不同set之间并没有connection,可以通过更改sorting的规则,构建不同的set,这样就实现了intra set的connection。

这里写到dsvt的block里,第一层以X轴去排序,第二层以Y轴去排序。

 

Hybrid window partition

上面解决了subset之间的信息交互,window之间的同样也需要解决。这里参考了swin transformer里的window shifting的方法。

3.3  Attention-style 3D Pooling

pillar的方法,不需要此模块,经过pfn之后,直接就是bev了。

voxel的方法,还需要通过pooling的方法,压缩到bev。

每个局部区域,lxwxh, 包含p个非空voxel;

首先pad成dense的;

然后采用标准的max pooling;

然后用这个max pooling的结果构建query vector,使用原始的作为key和val,attention得到最终的结果。

4. 代码

config文件
VFE:
    NAME: DynPillarVFE3D
    WITH_DISTANCE: False
    USE_ABSLOTE_XYZ: True
    USE_NORM: True
    NUM_FILTERS: [ 192, 192 ]

  BACKBONE_3D:
    NAME: DSVT
    INPUT_LAYER:
      sparse_shape: [468, 468, 1]
      downsample_stride: []
      d_model: [192]
      set_info: [[36, 4]]
      window_shape: [[12, 12, 1]]
      hybrid_factor: [2, 2, 1] # x, y, z
      shifts_list: [[[0, 0, 0], [6, 6, 0]]]
      normalize_pos: False

    block_name: ['DSVTBlock']
    set_info: [[36, 4]]
    d_model: [192]
    nhead: [8]
    dim_feedforward: [384]
    dropout: 0.0
    activation: gelu
    output_shape: [468, 468]
    conv_out_channel: 192
    # ues_checkpoint: True

        

 VFE:
    NAME: DynPillarVFE3D
    WITH_DISTANCE: False
    USE_ABSLOTE_XYZ: True
    USE_NORM: True
    NUM_FILTERS: [ 192, 192 ]

  BACKBONE_3D:
    NAME: DSVT
    INPUT_LAYER:
      sparse_shape: [468, 468, 32]
      downsample_stride: [[1, 1, 4], [1, 1, 4], [1, 1, 2]]
      d_model: [192, 192, 192, 192]
      set_info: [[48, 1], [48, 1], [48, 1], [48, 1]]
      window_shape: [[12, 12, 32], [12, 12, 8], [12, 12, 2], [12, 12, 1]]
      hybrid_factor: [2, 2, 1] # x, y, z
      shifts_list: [[[0, 0, 0], [6, 6, 0]], [[0, 0, 0], [6, 6, 0]], [[0, 0, 0], [6, 6, 0]], [[0, 0, 0], [6, 6, 0]]]
      normalize_pos: False
    
    block_name: ['DSVTBlock','DSVTBlock','DSVTBlock','DSVTBlock']
    set_info: [[48, 1], [48, 1], [48, 1], [48, 1]]
    d_model: [192, 192, 192, 192]
    nhead: [8, 8, 8, 8]
    dim_feedforward: [384, 384, 384, 384]
    dropout: 0.0 
    activation: gelu
    reduction_type: 'attention'
    output_shape: [468, 468]
    conv_out_channel: 192
    # ues_checkpoint: True
DynPillarVFE3D 是啥?

利用pytorch实现的动态的pillar或者voxel的feature提取,不限制pillar或者voxel的数量,不限制pillar或者voxel内部点的数量。这个是相对之前的voxel_generator hard方式,保留了更完整的信息。

class DynamicPillarVFE_3d(VFETemplate):
    def __init__(self, model_cfg, num_point_features, voxel_size, grid_size, point_cloud_range, **kwargs):
        super().__init__(model_cfg=model_cfg)

        self.use_norm = self.model_cfg.USE_NORM
        self.with_distance = self.model_cfg.WITH_DISTANCE
        self.use_absolute_xyz = self.model_cfg.USE_ABSLOTE_XYZ
        num_point_features += 6 if self.use_absolute_xyz else 3
        if self.with_distance:
            num_point_features += 1

        self.num_filters = self.model_cfg.NUM_FILTERS
        assert len(self.num_filters) > 0
        num_filters = [num_point_features] + list(self.num_filters)

        pfn_layers = []
        for i in range(len(num_filters) - 1):
            in_filters = num_filters[i]
            out_filters = num_filters[i + 1]
            pfn_layers.append(
                PFNLayerV2(in_filters, out_filters, self.use_norm, last_layer=(i >= len(num_filters) - 2))
            )
        self.pfn_layers = nn.ModuleList(pfn_layers)

        self.voxel_x = voxel_size[0]
        self.voxel_y = voxel_size[1]
        self.voxel_z = voxel_size[2]
        self.x_offset = self.voxel_x / 2 + point_cloud_range[0]
        self.y_offset = self.voxel_y / 2 + point_cloud_range[1]
        self.z_offset = self.voxel_z / 2 + point_cloud_range[2]

        self.scale_xyz = grid_size[0] * grid_size[1] * grid_size[2]
        self.scale_yz = grid_size[1] * grid_size[2]
        self.scale_z = grid_size[2]

        self.grid_size = torch.tensor(grid_size).cuda()
        self.voxel_size = torch.tensor(voxel_size).cuda()
        self.point_cloud_range = torch.tensor(point_cloud_range).cuda()

    def get_output_feature_dim(self):
        return self.num_filters[-1]

    def forward(self, batch_dict, **kwargs):
        points = batch_dict['points'] # (batch_idx, x, y, z, i, e)

        points_coords = torch.floor((points[:, [1,2,3]] - self.point_cloud_range[[0,1,2]]) / self.voxel_size[[0,1,2]]).int()
        mask = ((points_coords >= 0) & (points_coords < self.grid_size[[0,1,2]])).all(dim=1)
        points = points[mask]
        points_coords = points_coords[mask]
        points_xyz = points[:, [1, 2, 3]].contiguous()

        merge_coords = points[:, 0].int() * self.scale_xyz + \
                       points_coords[:, 0] * self.scale_yz + \
                       points_coords[:, 1] * self.scale_z + \
                       points_coords[:, 2]

        unq_coords, unq_inv, unq_cnt = torch.unique(merge_coords, return_inverse=True, return_counts=True, dim=0)

        points_mean = torch_scatter.scatter_mean(points_xyz, unq_inv, dim=0)
        f_cluster = points_xyz - points_mean[unq_inv, :]

        f_center = torch.zeros_like(points_xyz)
        f_center[:, 0] = points_xyz[:, 0] - (points_coords[:, 0].to(points_xyz.dtype) * self.voxel_x + self.x_offset)
        f_center[:, 1] = points_xyz[:, 1] - (points_coords[:, 1].to(points_xyz.dtype) * self.voxel_y + self.y_offset)
        # f_center[:, 2] = points_xyz[:, 2] - self.z_offset
        f_center[:, 2] = points_xyz[:, 2] - (points_coords[:, 2].to(points_xyz.dtype) * self.voxel_z + self.z_offset)

        if self.use_absolute_xyz:
            features = [points[:, 1:], f_cluster, f_center]
        else:
            features = [points[:, 4:], f_cluster, f_center]

        if self.with_distance:
            points_dist = torch.norm(points[:, 1:4], 2, dim=1, keepdim=True)
            features.append(points_dist)
        features = torch.cat(features, dim=-1)

        for pfn in self.pfn_layers:
            features = pfn(features, unq_inv)

        # generate voxel coordinates
        unq_coords = unq_coords.int()
        voxel_coords = torch.stack((unq_coords // self.scale_xyz,
                                    (unq_coords % self.scale_xyz) // self.scale_yz,
                                    (unq_coords % self.scale_yz) // self.scale_z,
                                    unq_coords % self.scale_z), dim=1)
        voxel_coords = voxel_coords[:, [0, 3, 2, 1]]

        batch_dict['pillar_features'] = batch_dict['voxel_features'] = features
        batch_dict['voxel_coords'] = voxel_coords

        return batch_dict
dsvt_input_layer.py

准备工作,提前划分好窗口,窗口内的set,voxel的index等

 ''' 
    This class converts the output of vfe to dsvt input.
    We do in this class:
    1. Window partition: partition voxels to non-overlapping windows.
    2. Set partition: generate non-overlapped and size-equivalent local sets within each window.
    3. Pre-compute the downsample infomation between two consecutive stages.
    4. Pre-compute the position embedding vectors.

    Args:
        sparse_shape (tuple[int, int, int]): Shape of input space (xdim, ydim, zdim).
        window_shape (list[list[int, int, int]]): Window shapes (winx, winy, winz) in different stages. Length: stage_num.
        downsample_stride (list[list[int, int, int]]): Downsample strides between two consecutive stages. 
            Element i is [ds_x, ds_y, ds_z], which is used between stage_i and stage_{i+1}. Length: stage_num - 1.
        d_model (list[int]): Number of input channels for each stage. Length: stage_num.
        set_info (list[list[int, int]]): A list of set config for each stage. Eelement i contains 
            [set_size, block_num], where set_size is the number of voxel in a set and block_num is the
            number of blocks for stage i. Length: stage_num.
        hybrid_factor (list[int, int, int]): Control the window shape in different blocks. 
            e.g. for block_{0} and block_{1} in stage_0, window shapes are [win_x, win_y, win_z] and 
            [win_x * h[0], win_y * h[1], win_z * h[2]] respectively.
        shift_list (list): Shift window. Length: stage_num.
        normalize_pos (bool): Whether to normalize coordinates in position embedding.
    '''
dsvt.py

根据配置文件的两层attention操作。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值