- 论文:PointSIFT: A SIFT-like Network Module for 3D Point Cloud Semantic Segmentation
- 来源:ECCV 2020
- 机构:清华大学 & 上海交通大学
- 论文:http://arxiv.org/abs/1807.00652
- 代码:
- https://github.com/MVIG-SJTU/pointSIFT/
- https://github.com/lelouedec/3DNetworksPytorch
- 灵感来源:来着SIFT方法,该方法是提取2D图像中图像特征的方法,来生成类似图像中形状的描述
- 方法实现:设计一个网络PointSIFT,能够提取物体中方向的信息同时能适应不同尺度的物体
- 实验结果:在ScanNet数据集上语义分割任务中,相比PointNet++网络IOU提升2.3%,Acc提升3.22%
网络整体
整体网络如下,PointSIFT对PointNet++中的MLP层进行了改进。
PointSIFT模块
通过堆叠多个 OE 单元,使得网络能够适应不同尺度的形状特征。在 PointSIFT 模块中堆叠多个 OE 单元,每个单元的感受野逐渐增大,以捕捉更大范围内的局部特征。使用shortcuts将各级 OE 单元的输出连接起来,然后通过一个点卷积(point-wise convolution)将多尺度特征融合,最终输出具有多尺度感知能力的特征表示。
Orientation-Encoding Unit
OE单元是一个方向编码单元来描述八个关键方向。
- 其中主要两个步骤:
- 八邻域搜索(Stacked 8-Neighborhood Search)对于每个输入点,PointSIFT 首先按照坐标轴将空间划分为八个象限,并在每个象限中找到距离最近的点作为邻居。如果某个象限内没有点,则将输入点自身复制为其最近邻。
方向编码卷积(Orientation-Encoding Convolution, OEC):在八邻域点上进行三阶段的卷积操作(沿 X, Y, Z 轴),将这些点的特征进行融合。具体来说,卷积操作将 2×2×2 立方体中的特征依次沿各轴进行卷积,最终输出包含方向信息的特征表示。
- 八邻域搜索(Stacked 8-Neighborhood Search)对于每个输入点,PointSIFT 首先按照坐标轴将空间划分为八个象限,并在每个象限中找到距离最近的点作为邻居。如果某个象限内没有点,则将输入点自身复制为其最近邻。
代码实现
我在文章顶端中,有俩个PointSIFT实现版本分别是pytorch和tensorflow实现,大家根据需求迁移到自己的项目中,下面我讲以pytorch版本网络设计对代码进行注释,具体的实现细节可以参考上面的链接中
class PointSIFT(nn.Module):
def __init__(self, nb_classes):
super(PointSIFT, self).__init__()
self.num_classes = nb_classes
# 第一个 PointSIFT 残差模块,用于提取局部特征,半径为 0.1,输出通道为 64。
self.pointsift_res_m3 = PointSIFT_res_module(radius=0.1, output_channel=64, merge='concat')
# 第一个 PointNet 下采样模块,采样 1024 个点,半径为 0.1,32 个邻居点,输出特征维度为 128。
self.pointnet_sa_m3 = Pointnet_SA_module(npoint=1024, radius=0.1, nsample=32, in_channel=64, mlp=[64, 128], group_all=False)
# 第二个 PointSIFT 残差模块,半径为 0.2,输出通道为 128。
self.pointsift_res_m4 = PointSIFT_res_module(radius=0.2, output_channel=128, extra_input_channel=128)
# 第二个 PointNet 下采样模块,采样 256 个点,半径为 0.2,32 个邻居点,输出特征维度为 256。
self.pointnet_sa_m4 = Pointnet_SA_module(npoint=256, radius=0.2, nsample=32, in_channel=128, mlp=[128, 256], group_all=False)
# 第三个 PointSIFT 残差模块,第一个子模块,半径为 0.2,输出通道为 256。
self.pointsift_res_m5_1 = PointSIFT_res_module(radius=0.2, output_channel=256, extra_input_channel=256)
# 第三个 PointSIFT 残差模块,第二个子模块,半径为 0.2,输出通道为 512。
self.pointsift_res_m5_2 = PointSIFT_res_module(radius=0.2, output_channel=512, extra_input_channel=256, same_dim=True)
# 1D 卷积层,用于合并和处理 PointSIFT 残差模块的输出,输入通道 768,输出通道 512。
self.conv1 = conv1d_bn(768, 512, 1, stride=1, activation='none')
# 第四个 PointNet 下采样模块,采样 64 个点,半径为 0.2,32 个邻居点,输出特征维度为 512。
self.pointnet_sa_m6 = Pointnet_SA_module(npoint=64, radius=0.2, nsample=32, in_channel=512, mlp=[512, 512], group_all=False)
# 第一个 PointNet 上采样模块,将较低分辨率特征恢复到较高分辨率,输入和输出通道均为 512。
self.pointnet_fp_m0 = Pointnet_fp_module([512, 512], [512, 512])
# 一系列 PointSIFT 模块,半径为 0.5,输出通道均为 512。
self.pointsift_m0 = PointSIFT_module(radius=0.5, output_channel=512, extra_input_channel=512)
self.pointsift_m1 = PointSIFT_module(radius=0.5, output_channel=512, extra_input_channel=512)
self.pointsift_m2 = PointSIFT_module(radius=0.5, output_channel=512, extra_input_channel=512)
# 1D 卷积层,用于合并 PointSIFT 模块的输出,输入通道 512,输出通道 512。
self.conv2 = conv1d_bn(512, 512, 1, stride=1, activation='none')
# 第二个 PointNet 上采样模块,将较低分辨率特征恢复到较高分辨率,输入和输出通道均为 256。
self.pointnet_fp_m1 = Pointnet_fp_module([256, 256], [256, 256])
# 一系列 PointSIFT 模块,半径为 0.25,输出通道均为 256。
self.pointsift_m3 = PointSIFT_module(radius=0.25, output_channel=256, extra_input_channel=256)
self.pointsift_m4 = PointSIFT_module(radius=0.25, output_channel=256, extra_input_channel=256)
# 1D 卷积层,用于合并 PointSIFT 模块的输出,输入通道 256,输出通道 256。
self.conv3 = conv1d_bn(256, 256, 1, stride=1, activation='none')
# 第三个 PointNet 上采样模块,将较低分辨率特征恢复到较高分辨率,输入和输出通道均为 128。
self.pointnet_fp_m2 = Pointnet_fp_module([128, 128, 128], [128, 128, 128])
# 第五个 PointSIFT 模块,半径为 0.1,输出通道为 128。
self.pointsift_m5 = PointSIFT_module(radius=0.1, output_channel=128, extra_input_channel=128)
### 全连接层
# 1D 卷积层,用于分类任务的全连接层,输入通道 128,输出通道 128。
self.conv_fc = conv1d_bn(128, 128, 1, stride=1, activation='none')
# Dropout 层,防止过拟合,保留率为 0.5。
self.drop_fc = nn.Dropout(p=0.5)
# 最终的分类层,1D 卷积层,输出通道数为 2(假设分类为 2 类)。
self.conv2_fc = conv1d_bn(128, 2, 1, stride=1, activation='none')
def forward(self, xyz, points=None):
"""
Input:
xyz: 输入的点云数据,形状为 (B * N * 3),B 为批次大小,N 为点的数量,3 为坐标维度。
points: 附加的点特征,默认为 None。
"""
B = xyz.size()[0]
# 第一层 PointSIFT 残差模块 + PointNet 下采样模块
l3_xyz, l3_points = self.pointsift_res_m3(xyz, points)
c3_xyz, c3_points = self.pointnet_sa_m3(l3_xyz, l3_points)
# 第二层 PointSIFT 残差模块 + PointNet 下采样模块
l4_xyz, l4_points = self.pointsift_res_m4(c3_xyz, c3_points)
c4_xyz, c4_points = self.pointnet_sa_m4(l4_xyz, l4_points)
# 第三层 PointSIFT 残差模块
l5_xyz, l5_points = self.pointsift_res_m5_1(c4_xyz, c4_points)
l5_2_xyz, l5_2_points = self.pointsift_res_m5_2(l5_xyz, l5_points)
# 将第三层 PointSIFT 残差模块的两个输出拼接在一起
l2_cat_points = torch.cat([l5_points, l5_2_points], dim=2)
# 通过 1D 卷积层处理拼接后的特征
fc_l2_points = self.conv1(l2_cat_points.permute(0, 2, 1)).permute(0, 2, 1)
# 第四层 PointNet 下采样模块
l3b_xyz, l3b_points = self.pointnet_sa_m6(l5_2_xyz, fc_l2_points)
# 第一层 PointNet 上采样模块
l2_points = self.pointnet_fp_m0(c4_xyz, l3b_xyz, c4_points, l3_points).permute(0, 2, 1)
_, l2_points_1 = self.pointsift_m0(c4_xyz, l2_points)
_, l2_points_2 = self.pointsift_m1(c4_xyz, l2_points)
_, l2_points_3 = self.pointsift_m2(c4_xyz, l2_points)
# 将 PointSIFT 模块的输出拼接在一起并通过 1D 卷积层处理
l2_points = torch.cat([l2_points_1, l2_points_2, l2_points_3], dim=-1)
l2_points = self.conv2(l2_points)
# 第二层 PointNet 上采样模块
l1_points = self.pointnet_fp_m1(c3_xyz, c4_xyz, c3_points, l2_points).permute(0, 2, 1)
_, l1_points_1 = self.points
_, l1_points_1 = self.pointsift_m3(c3_xyz, l1_points)
_, l1_points_2 = self.pointsift_m4(c3_xyz, l1_points)
# 将 PointSIFT 模块的输出拼接在一起
l1_points = torch.cat([l1_points_1, l1_points_2], dim=-1)
# 通过 1D 卷积层处理拼接后的特征
l0_points = self.conv3(l1_points)
# 第三层 PointNet 上采样模块
l0_points = self.pointnet_fp_m2(l3_xyz, c3_xyz, l3_points, l0_points).permute(0, 2, 1)
# 最后一层 PointSIFT 模块
_, l0_points_1 = self.pointsift_m5(l3_xyz, l0_points)
# 全连接层用于分类
net = self.conv_fc(l0_points_1)
net = self.drop_fc(net)
net = self.conv2_fc(net)
return net
@staticmethod
def get_loss(input, target):
"""
损失函数,计算输入与目标之间的交叉熵损失。
"""
classify_loss = nn.CrossEntropyLoss()
loss = classify_loss(input, target)
return loss
def initialize_weights(self):
"""
权重初始化方法,对模型中的卷积层、批量归一化层和全连接层进行初始化。
"""
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
如何改进PointNet++
这篇论文中贡献了两个模块分别是PointSIFT和Orientation-Encoding Unit,这两个模块能够直接替换PointNet++中MLP层,或者跟其他模块在进行组合形成新的模块进行创新模块。