PointNeXt网络部分详解

PointNeXt

PointNet++ 是用于点云理解的最有影响力的神经架构之一。尽管 PointNet++ 的准确性已被 PointMLP 和 Point Transformer 等最近的网络在很大程度上超越,但我们发现很大一部分性能提升是由于改进了训练策略,即数据增强和优化技术,以及增加了模型大小而不是架构创新。因此,PointNet++ 的全部潜力还有待探索。

针对PointNet++网络的修改

PointNeXt主要做了两个方面的事情,第一方面是数据增强,第二个是修改了模型部分。

本文主要针对的是在模型部分的修改进行了详细的介绍,并不涉及数据增强部分的内容。具体的网络模型如下图所示。

在这里插入图片描述

与PointNet++不同的是,在SA层后面又添加了一层InvResMLP,以此来缓解梯度消失的问题。

假设输入的点云是[N, 4]维度的。这时是第一个输入,也就是head,其中N表示,这一个场景内的点云总共由N个点组成,每一个点由4个特征值表示。首先经过一次Conv1d,变成[N, 32]维度,之后经过一个SA层和一个InvResMLP变成[N/4, 64]。下面来结合代码看一下这个SA层和InvResMLP具体的工作原理。

class SetAbstraction(nn.Module):
    def __init__(self,
                 in_channels, out_channels,
                 layers=1,
                 stride=1,
                 group_args={'NAME': 'ballquery',
                             'radius': 0.1, 'nsample': 16},
                 norm_args={'norm': 'bn1d'},
                 act_args={'act': 'relu'},
                 conv_args=None,
                 sample_method='fps',
                 use_res=False,
                 is_head=False,
                 ):
        super().__init__()
        # is_head:表示是否为初始输入的点云,如果是则使用conv1d,如果不是,则使用conv2d。
        # all_aggr:表示是否要把所有的点group成1个
        self.stride = stride
        self.is_head = is_head
        # current blocks aggregates all spatial information.
        self.all_aggr = not is_head and stride == 1
        # use_res = False
        self.use_res = use_res and not self.all_aggr and not self.is_head
        mid_channel = out_channels // 2 if stride > 1 else out_channels
        channels = [in_channels] + [mid_channel] * \
            (layers - 1) + [out_channels]
        # 如果不是head输出通道就要加xyz_position
        channels[0] = in_channels + 3 * (not is_head)
        # 如果是head,则使用conv1d,如果不是,则使用conv2d。
        create_conv = create_convblock1d if is_head else create_convblock2d
        convs = []
        for i in range(len(channels) - 1):
            convs.append(create_conv(channels[i], channels[i + 1],
                                     norm_args=norm_args if not is_head else None,
                                     act_args=None if i == len(channels) - 2
                                     and (self.use_res or is_head) else act_args,
                                     **conv_args)
                         )
        self.convs = nn.Sequential(*convs)
        # 如果不是head,则需要进行下采样
        if not is_head:
            if self.all_aggr:
                # 如果是对所有的点进行,则不需要下采样的点数以及半径。
                group_args.nsample = None
                group_args.radius = None
            self.grouper = create_grouper(group_args)
            self.pool = lambda x: torch.max(x, dim=-1, keepdim=False)[0]
            if sample_method.lower() == 'fps':
                self.sample_fn = furthest_point_sample
            elif sample_method.lower() == 'random':
                self.sample_fn = random_sample
                
    def forward(self, px):
        p, x = px
        if self.is_head:
            x = self.convs(x)  # (n, c)
        else:
            if not self.all_aggr:
                idx = self.sample_fn(p, p.shape[1] // self.stride).long()
                new_p = torch.gather(p, 1, idx.unsqueeze(-1).expand(-1, -1, 3))
            else:
                new_p = p
            # [dp,xj]: (B, 3 + C, npoint, nsample)
            dp, xj = self.grouper(new_p, p, x)
            # pool:torch.max(x, dim=-1, keepdim=False)[0], 由于dim=-1,所以是将最后一维度的nsample变为1
            # 对照下文也就是K的那一维度被max_pool.
            x = self.pool(self.convs(torch.cat((dp, xj), dim=1)))
            if self.use_res:
                x = self.act(x + identity)
            p = new_p
        return p, x

假设nsample=32,那么输入[B, N, 32]的点云经过SA层的每一层的输出分别为,经过subsample变为[B, N/4, 32],经过Grouping变为[B, N/4, 32(K), 32],经过MLP变为[B, N/4, 32(K), 64],经过Reduction变为[B, N/4, 64]。后续的从64-128,128-256,256-512的操作与此相似,不再过多赘述。

接下来我们再来看InvResMLP层都做了那些事情,此时经过SA层我们得到的输出是[B, N/4, 64],以此作为InvResMLP层的输入。结合InvResMLP层的代码来看。

class InvResMLP(nn.Module):
    def __init__(self,
                 in_channels,
                 norm_args=None,
                 act_args=None,
                 aggr_args={'feature_type': 'dp_fj', "reduction": 'max'},
                 group_args={'NAME': 'ballquery'},
                 conv_args=None,
                 expansion=1,
                 use_res=True,
                 num_posconvs=2,
                 less_act=False,
                 **kwargs
                 ):
        super().__init__()
        self.use_res = use_res
        mid_channels = in_channels * expansion
        self.convs = LocalAggregation([in_channels, in_channels],
                                      norm_args=norm_args, act_args=act_args if num_posconvs > 0 else None,
                                      group_args=group_args, conv_args=conv_args,
                                      **aggr_args, **kwargs)
        if num_posconvs < 1:
            channels = []
        elif num_posconvs == 1:
            channels = [in_channels, in_channels]
        else:
            channels = [in_channels, mid_channels, in_channels]
        pwconv = []
        # point wise after depth wise conv (without last layer)
        for i in range(len(channels) - 1):
            pwconv.append(create_convblock1d(channels[i], channels[i + 1],
                                             norm_args=norm_args,
                                             act_args=act_args if
                                             (i != len(channels) - 2) and not less_act else None,
                                             **conv_args)
                          )
        self.pwconv = nn.Sequential(*pwconv)
        self.act = create_act(act_args)

    def forward(self, px):
        p, x = px
        # identity 就是输入的feature
        identity = x
        x = self.convs([p, x])
        x = self.pwconv(x)
        # 判断是否use_res,如果use_res则将之前保存的x直接加到现在的x上。
        if x.shape[-1] == identity.shape[-1] and self.use_res:
            x += identity
        x = self.act(x)
        return [p, x]

其中的LocalAggregation这一部分执行的是图中Grouping,MLP(64)以及reduction这一部分的操作,Grouping与SA层中的非head的Grouping操作基本一致,MLP(64)做的事情是将group完成的数据进行conv2d,然后进行max_pool。这三步结束后得到的输出是[B, N/4, 256]。之后经过两次conv1d,第一次conv1d之后输出为[B, N/4, 64],第二次conv1d之后输出为[B, N/4, 64],这时第二次conv1d的输出和最开始Grouping之前的输入进行一个对应项的相加,两个[B, N/4, 64]的相加(注意是相加而不是concat)。最终的输出仍为[B, N/4, 64]。之后的每一个InvResMLP都与此完全一致。

上采样(FeaturePropogation)部分与PointNet++相同。

总结

相比PointNet++,PointNeXt的网络模型有以下3点调整:

①在最初的输入后面添加了一个单独的MLP

②没有使用PointNet++中的多尺度Msg

③添加了InvResMLP

  • 5
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值