PointNeXt
PointNet++ 是用于点云理解的最有影响力的神经架构之一。尽管 PointNet++ 的准确性已被 PointMLP 和 Point Transformer 等最近的网络在很大程度上超越,但我们发现很大一部分性能提升是由于改进了训练策略,即数据增强和优化技术,以及增加了模型大小而不是架构创新。因此,PointNet++ 的全部潜力还有待探索。
针对PointNet++网络的修改
PointNeXt主要做了两个方面的事情,第一方面是数据增强,第二个是修改了模型部分。
本文主要针对的是在模型部分的修改进行了详细的介绍,并不涉及数据增强部分的内容。具体的网络模型如下图所示。
与PointNet++不同的是,在SA层后面又添加了一层InvResMLP,以此来缓解梯度消失的问题。
假设输入的点云是[N, 4]维度的。这时是第一个输入,也就是head,其中N表示,这一个场景内的点云总共由N个点组成,每一个点由4个特征值表示。首先经过一次Conv1d,变成[N, 32]维度,之后经过一个SA层和一个InvResMLP变成[N/4, 64]。下面来结合代码看一下这个SA层和InvResMLP具体的工作原理。
class SetAbstraction(nn.Module):
def __init__(self,
in_channels, out_channels,
layers=1,
stride=1,
group_args={'NAME': 'ballquery',
'radius': 0.1, 'nsample': 16},
norm_args={'norm': 'bn1d'},
act_args={'act': 'relu'},
conv_args=None,
sample_method='fps',
use_res=False,
is_head=False,
):
super().__init__()
# is_head:表示是否为初始输入的点云,如果是则使用conv1d,如果不是,则使用conv2d。
# all_aggr:表示是否要把所有的点group成1个
self.stride = stride
self.is_head = is_head
# current blocks aggregates all spatial information.
self.all_aggr = not is_head and stride == 1
# use_res = False
self.use_res = use_res and not self.all_aggr and not self.is_head
mid_channel = out_channels // 2 if stride > 1 else out_channels
channels = [in_channels] + [mid_channel] * \
(layers - 1) + [out_channels]
# 如果不是head输出通道就要加xyz_position
channels[0] = in_channels + 3 * (not is_head)
# 如果是head,则使用conv1d,如果不是,则使用conv2d。
create_conv = create_convblock1d if is_head else create_convblock2d
convs = []
for i in range(len(channels) - 1):
convs.append(create_conv(channels[i], channels[i + 1],
norm_args=norm_args if not is_head else None,
act_args=None if i == len(channels) - 2
and (self.use_res or is_head) else act_args,
**conv_args)
)
self.convs = nn.Sequential(*convs)
# 如果不是head,则需要进行下采样
if not is_head:
if self.all_aggr:
# 如果是对所有的点进行,则不需要下采样的点数以及半径。
group_args.nsample = None
group_args.radius = None
self.grouper = create_grouper(group_args)
self.pool = lambda x: torch.max(x, dim=-1, keepdim=False)[0]
if sample_method.lower() == 'fps':
self.sample_fn = furthest_point_sample
elif sample_method.lower() == 'random':
self.sample_fn = random_sample
def forward(self, px):
p, x = px
if self.is_head:
x = self.convs(x) # (n, c)
else:
if not self.all_aggr:
idx = self.sample_fn(p, p.shape[1] // self.stride).long()
new_p = torch.gather(p, 1, idx.unsqueeze(-1).expand(-1, -1, 3))
else:
new_p = p
# [dp,xj]: (B, 3 + C, npoint, nsample)
dp, xj = self.grouper(new_p, p, x)
# pool:torch.max(x, dim=-1, keepdim=False)[0], 由于dim=-1,所以是将最后一维度的nsample变为1
# 对照下文也就是K的那一维度被max_pool.
x = self.pool(self.convs(torch.cat((dp, xj), dim=1)))
if self.use_res:
x = self.act(x + identity)
p = new_p
return p, x
假设nsample=32,那么输入[B, N, 32]的点云经过SA层的每一层的输出分别为,经过subsample变为[B, N/4, 32],经过Grouping变为[B, N/4, 32(K), 32],经过MLP变为[B, N/4, 32(K), 64],经过Reduction变为[B, N/4, 64]。后续的从64-128,128-256,256-512的操作与此相似,不再过多赘述。
接下来我们再来看InvResMLP层都做了那些事情,此时经过SA层我们得到的输出是[B, N/4, 64],以此作为InvResMLP层的输入。结合InvResMLP层的代码来看。
class InvResMLP(nn.Module):
def __init__(self,
in_channels,
norm_args=None,
act_args=None,
aggr_args={'feature_type': 'dp_fj', "reduction": 'max'},
group_args={'NAME': 'ballquery'},
conv_args=None,
expansion=1,
use_res=True,
num_posconvs=2,
less_act=False,
**kwargs
):
super().__init__()
self.use_res = use_res
mid_channels = in_channels * expansion
self.convs = LocalAggregation([in_channels, in_channels],
norm_args=norm_args, act_args=act_args if num_posconvs > 0 else None,
group_args=group_args, conv_args=conv_args,
**aggr_args, **kwargs)
if num_posconvs < 1:
channels = []
elif num_posconvs == 1:
channels = [in_channels, in_channels]
else:
channels = [in_channels, mid_channels, in_channels]
pwconv = []
# point wise after depth wise conv (without last layer)
for i in range(len(channels) - 1):
pwconv.append(create_convblock1d(channels[i], channels[i + 1],
norm_args=norm_args,
act_args=act_args if
(i != len(channels) - 2) and not less_act else None,
**conv_args)
)
self.pwconv = nn.Sequential(*pwconv)
self.act = create_act(act_args)
def forward(self, px):
p, x = px
# identity 就是输入的feature
identity = x
x = self.convs([p, x])
x = self.pwconv(x)
# 判断是否use_res,如果use_res则将之前保存的x直接加到现在的x上。
if x.shape[-1] == identity.shape[-1] and self.use_res:
x += identity
x = self.act(x)
return [p, x]
其中的LocalAggregation这一部分执行的是图中Grouping,MLP(64)以及reduction这一部分的操作,Grouping与SA层中的非head的Grouping操作基本一致,MLP(64)做的事情是将group完成的数据进行conv2d,然后进行max_pool。这三步结束后得到的输出是[B, N/4, 256]。之后经过两次conv1d,第一次conv1d之后输出为[B, N/4, 64],第二次conv1d之后输出为[B, N/4, 64],这时第二次conv1d的输出和最开始Grouping之前的输入进行一个对应项的相加,两个[B, N/4, 64]的相加(注意是相加而不是concat)。最终的输出仍为[B, N/4, 64]。之后的每一个InvResMLP都与此完全一致。
上采样(FeaturePropogation)部分与PointNet++相同。
总结
相比PointNet++,PointNeXt的网络模型有以下3点调整:
①在最初的输入后面添加了一个单独的MLP
②没有使用PointNet++中的多尺度Msg
③添加了InvResMLP