[ICCV2021 Oral] PoinTr Code Walkthrough
This is entirely my own understanding.
Overall Code
First, here is the full model code, taken from the PoinTr source.
The code is then explained by following the forward pass.
import torch
import torch.nn as nn

# PCTransformer, Fold, fps and ChamferDistanceL1 are defined elsewhere in the PoinTr repo.
class PoinTr(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        # num_pred: 6144, num_query: 96, knn_layer: 1, trans_dim: 384
        self.trans_dim = 384
        self.knn_layer = 1
        self.num_pred = 6144
        self.num_query = 96
        self.fold_step = int(pow(self.num_pred // self.num_query, 0.5) + 0.5)
        self.base_model = PCTransformer(in_chans=3, embed_dim=self.trans_dim, depth=[6, 8], drop_rate=0., num_query=self.num_query, knn_layer=self.knn_layer)
        self.foldingnet = Fold(self.trans_dim, step=self.fold_step, hidden_dim=256)  # rebuild a cluster point
        self.increase_dim = nn.Sequential(
            nn.Conv1d(self.trans_dim, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv1d(1024, 1024, 1)
        )
        self.reduce_map = nn.Linear(self.trans_dim + 1027, self.trans_dim)
        self.build_loss_func()

    def build_loss_func(self):
        self.loss_func = ChamferDistanceL1()

    def get_loss(self, ret, gt):
        loss_coarse = self.loss_func(ret[0], gt)
        loss_fine = self.loss_func(ret[1], gt)
        return loss_coarse, loss_fine

    def forward(self, xyz):
        q, coarse_point_cloud = self.base_model(xyz)  # B M C and B M 3
        B, M, C = q.shape
        global_feature = self.increase_dim(q.transpose(1, 2)).transpose(1, 2)  # B M 1024
        global_feature = torch.max(global_feature, dim=1)[0]  # B 1024
        rebuild_feature = torch.cat([
            global_feature.unsqueeze(-2).expand(-1, M, -1),
            q,
            coarse_point_cloud], dim=-1)  # B M 1027 + C
        rebuild_feature = self.reduce_map(rebuild_feature.reshape(B * M, -1))  # BM C
        # NOTE: try to rebuild pc
        # coarse_point_cloud = self.refine_coarse(rebuild_feature).reshape(B, M, 3)
        # NOTE: foldingNet
        relative_xyz = self.foldingnet(rebuild_feature).reshape(B, M, 3, -1)  # B M 3 S
        rebuild_points = (relative_xyz + coarse_point_cloud.unsqueeze(-1)).transpose(2, 3).reshape(B, -1, 3)  # B N 3
        # NOTE: fc
        # relative_xyz = self.refine(rebuild_feature)  # BM 3S
        # rebuild_points = (relative_xyz.reshape(B, M, 3, -1) + coarse_point_cloud.unsqueeze(-1)).transpose(2, 3).reshape(B, -1, 3)
        # cat the input
        inp_sparse = fps(xyz, self.num_query)
        coarse_point_cloud = torch.cat([coarse_point_cloud, inp_sparse], dim=1).contiguous()
        rebuild_points = torch.cat([rebuild_points, xyz], dim=1).contiguous()
        ret = (coarse_point_cloud, rebuild_points)
        return ret
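A quick sanity check on the numbers: fold_step = int(sqrt(6144 // 96) + 0.5) = int(8 + 0.5) = 8, so FoldingNet emits an 8 × 8 = 64-point patch per query, and 96 × 64 = 6144 = num_pred points in total. A hypothetical usage sketch (shapes follow the comments above; it assumes the repo's dependencies are importable):

model = PoinTr()
partial = torch.rand(4, 2048, 3)             # B, N, 3 partial input
gt = torch.rand(4, 8192, 3)                  # complete ground-truth cloud
coarse, fine = model(partial)                # coarse: 4, 192, 3; fine: 4, 8192, 3
loss_coarse, loss_fine = model.get_loss((coarse, fine), gt)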
Point Proxy
Given a point cloud $\{q_1, q_2, \ldots, q_N\}$, with $N = 2048$ in the paper, a Point Proxy represents a local region of the point cloud. It is a feature vector that captures the local structure around a point and can be computed as

$$F_i = F'_i + \varphi(q_i)$$

where $F'_i$ is the feature of point $q_i$ extracted with a DGCNN model, and $\varphi$ is another MLP that captures the location information of the Point Proxy.
DGCNN_Grouper
As described in the paper, the first step is to extract point-centered features from the input point cloud using a lightweight DGCNN with hierarchical downsampling.
# input point cloud [B, N, 3]; here N = 2048
coor, f = self.grouper(inpc.transpose(1, 2).contiguous())
Looking at the grouper concretely, it is a DGCNN_Grouper; its forward body is:
# first apply input_trans (a conv1d)
coor = x
f = self.input_trans(x)  # conv1d(3, 8)
# for each point in coor [B, 3, 2048], find its k nearest neighbors within coor itself
# and gather the corresponding features f
f = self.get_graph_feature(coor, f, coor, f)
f = self.layer1(f)
f = f.max(dim=-1, keepdim=False)[0]
# FPS downsampling to 512 points
coor_q, f_q = self.fps_downsample(coor, f, 512)
# for each point in coor_q [B, 3, 512], find its k nearest neighbors in coor [B, 3, 2048]
# and gather the corresponding features f
f = self.get_graph_feature(coor_q, f_q, coor, f)
f = self.layer2(f)
f = f.max(dim=-1, keepdim=False)[0]
coor = coor_q
# for each point in coor [B, 3, 512], find its k nearest neighbors within coor itself
# and gather the corresponding features f
f = self.get_graph_feature(coor, f, coor, f)
f = self.layer3(f)
f = f.max(dim=-1, keepdim=False)[0]
# FPS downsampling to 128 points
coor_q, f_q = self.fps_downsample(coor, f, 128)
# for each point in coor_q [B, 3, 128], find its k nearest neighbors in coor [B, 3, 512]
# and gather the corresponding features f
f = self.get_graph_feature(coor_q, f_q, coor, f)
f = self.layer4(f)
f = f.max(dim=-1, keepdim=False)[0]
coor = coor_q
return coor, f

In short, the grouper downsamples the 2048 input points to 512 and then to 128 centers; the returned coor holds the 128 center coordinates and f their features, and these centers are what become the Point Proxies.
Now look at the function get_graph_feature:
- For each point in coor_q, find its K nearest neighbors in coor_k (here K = 16) and record the indices idx.
- Using idx, gather the corresponding features from x_k; call the result feature.
- Finally, feature = torch.cat([feature - x_q, x_q], 1).
def get_graph_feature(coor_q, x_q, coor_k, x_k):
    # coor: bs, 3, np; x: bs, c, np
    k = 16
    batch_size = x_k.size(0)
    num_points_k = x_k.size(2)
    num_points_q = x_q.size(2)
    with torch.no_grad():
        _, idx = knn(coor_k, coor_q)  # bs k np
        assert idx.shape[1] == k
        idx_base = torch.arange(0, batch_size, device=x_q.device).view(-1, 1, 1) * num_points_k
        idx = idx + idx_base
        idx = idx.view(-1)
    num_dims = x_k.size(1)
    x_k = x_k.transpose(2, 1).contiguous()
    feature = x_k.view(batch_size * num_points_k, -1)[idx, :]
    feature = feature.view(batch_size, k, num_points_q, num_dims).permute(0, 3, 2, 1).contiguous()
    x_q = x_q.view(batch_size, num_dims, num_points_q, 1).expand(-1, -1, -1, k)
    feature = torch.cat((feature - x_q, x_q), dim=1)
    return feature
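The knn call above comes from an external CUDA extension in the repo. Purely for illustration, a pure-PyTorch stand-in with the same (dist, idx) convention (idx shaped bs, k, np_q, matching the assert above) might look like this; knn_torch is a hypothetical name, not the repo's API:

import torch

def knn_torch(coor_k, coor_q, k=16):
    # coor_k: bs, 3, np_k; coor_q: bs, 3, np_q
    # pairwise distances between every query point and every key point
    dist = torch.cdist(coor_q.transpose(1, 2), coor_k.transpose(1, 2))  # bs, np_q, np_k
    # k smallest distances -> the k nearest key points for each query point
    val, idx = dist.topk(k, dim=-1, largest=False)
    # return as bs, k, np_q to match the convention used by get_graph_feature
    return val.transpose(1, 2), idx.transpose(1, 2)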
The overall structure is shown in the figure below: the leftmost column is the input and the rightmost column is the output.
MLP
This part takes its inspiration from the position embedding in transformers: it encodes the global position of each Point Proxy.
pos = self.pos_embed(coor).transpose(1,2)
The features f extracted by the DGCNN are then passed through an input_proj operation:
x = self.input_proj(f).transpose(1,2)
Finally, adding pos and x together yields the Point Proxies.
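A minimal sketch of what pos_embed and input_proj might look like, inferred from the shapes rather than copied from the repo (dgcnn_dim is a placeholder for the grouper's output feature width):

self.pos_embed = nn.Sequential(
    nn.Conv1d(3, 128, 1),
    nn.BatchNorm1d(128),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Conv1d(128, self.trans_dim, 1)
)  # (B, 3, n) -> (B, trans_dim, n): encodes the proxy center coordinates
self.input_proj = nn.Conv1d(dgcnn_dim, self.trans_dim, 1)  # projects the DGCNN features to the transformer width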
Geometry-aware Transformer Block
To help the transformer make better use of the inductive bias of the 3D geometric structure of point clouds, the authors design a geometry-aware block to model the geometric relations. It is a plug-and-play module that can be combined with the attention in any transformer architecture.
A kNN model is used to capture the geometric relations in the point cloud. Given the query coordinates $p_Q$, we look up the features of the neighboring keys according to the key coordinates $p_K$. Then, following DGCNN, the local geometric structure is learned by aggregating features with a linear layer followed by max pooling. The geometric feature and the semantic feature are then concatenated and mapped back to the original dimension to form the output.
Point Cloud Generation Process
The paper generates two point clouds step by step: a coarse_point_cloud and the final dense point cloud.
Generating coarse_point_cloud
A coarse point cloud is generated first; the code is as follows:
for i, blk in enumerate(self.encoder):
    if i < self.knn_layer:
        x = blk(x + pos, knn_index)  # B N C
    else:
        x = blk(x + pos)
As mentioned above, x + pos is the Point Proxy.
self.knn_layer is the number of blocks that use the kNN branch; the authors set it to 1, so kNN is applied in the first block only.
Now let's look at blk, which implements the Geometry-aware Transformer Block; the interesting part is its forward method.
def forward(self, x, knn_index=None):
    # x = x + self.drop_path(self.attn(self.norm1(x)))
    norm_x = self.norm1(x)
    x_1 = self.attn(norm_x)
    if knn_index is not None:
        knn_f = get_graph_feature(norm_x, knn_index)
        knn_f = self.knn_map(knn_f)
        knn_f = knn_f.max(dim=1, keepdim=False)[0]
        x_1 = torch.cat([x_1, knn_f], dim=-1)
        x_1 = self.merge_map(x_1)
    x = x + self.drop_path(x_1)
    x = x + self.drop_path(self.mlp(self.norm2(x)))
    return x
From the code we can see:
- When knn_index is present, the input goes through two paths:
  - attention;
  - get_graph_feature, the same idea as in DGCNN; both paths are there to gather features.

The outputs of the two paths are concatenated, fused by merge_map, added back to the residual, and then passed through an MLP.
def get_graph_feature(x, knn_index, x_q=None):
    # x: bs, np, c; knn_index: bs*k*np
    k = 8
    batch_size, num_points, num_dims = x.size()
    num_query = x_q.size(1) if x_q is not None else num_points
    feature = x.view(batch_size * num_points, num_dims)[knn_index, :]
    feature = feature.view(batch_size, k, num_query, num_dims)
    x = x_q if x_q is not None else x
    x = x.view(batch_size, 1, num_query, num_dims).expand(-1, k, -1, -1)
    feature = torch.cat((feature - x, x), dim=-1)
    return feature  # b k np c
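For context, knn_index here is a flattened index tensor of length bs*k*np. A hedged sketch (build_knn_index is my name, not the repo's helper) of how such an index compatible with the view above could be built from the proxy coordinates:

def build_knn_index(coor, k=8):
    # coor: bs, 3, np -- proxy center coordinates
    bs, _, np_ = coor.shape
    dist = torch.cdist(coor.transpose(1, 2), coor.transpose(1, 2))  # bs, np, np
    idx = dist.topk(k, dim=-1, largest=False)[1]                    # bs, np, k
    idx = idx.transpose(1, 2).contiguous()                          # bs, k, np
    # offset each batch so the indices address the flattened (bs*np, c) view of x
    idx_base = torch.arange(bs, device=coor.device).view(-1, 1, 1) * np_
    return (idx + idx_base).view(-1)                                # bs*k*np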
When knn_index is absent, the input simply passes through attention and the MLP in sequence.
Query Generator
The queries serve as the initial state of the predicted proxies. To ensure the queries properly reflect the complete point cloud, the authors propose a Query Generator that dynamically generates them from the encoder output:
global_feature = self.increase_dim(x.transpose(1,2)) # B 1024 N
global_feature = torch.max(global_feature, dim=-1)[0] # B 1024
coarse_point_cloud = self.coarse_pred(global_feature).reshape(bs, -1, 3) # B M C(3)
Finally, the encoder's global feature and the coordinates are concatenated, and an MLP generates the predicted proxies:
query_feature = torch.cat([
    global_feature.unsqueeze(1).expand(-1, self.num_query, -1),
    coarse_point_cloud], dim=-1)  # B M C+3
q = self.mlp_query(query_feature.transpose(1, 2)).transpose(1, 2)  # B M C
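For reference, layer definitions consistent with the shape comments in these snippets could look like this (a sketch inferred from the shapes, not copied from the repo):

self.increase_dim = nn.Sequential(
    nn.Conv1d(self.trans_dim, 1024, 1),
    nn.BatchNorm1d(1024),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Conv1d(1024, 1024, 1)
)
self.coarse_pred = nn.Sequential(
    nn.Linear(1024, 1024),
    nn.ReLU(inplace=True),
    nn.Linear(1024, 3 * self.num_query)  # reshaped to (B, num_query, 3)
)
self.mlp_query = nn.Sequential(
    nn.Conv1d(1024 + 3, 1024, 1),        # global feature (1024) + coarse xyz (3)
    nn.BatchNorm1d(1024),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Conv1d(1024, self.trans_dim, 1)   # back to the transformer width C
)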
Multi-Scale Point Cloud Generation
The goal of the encoder-decoder network is to predict the missing part of an incomplete point cloud. However, the transformer decoder only yields predictions for the missing proxies, so the authors propose a multi-scale point cloud generation framework to recover the missing point cloud at full resolution.
Next, the queries are passed through the decoder:
for i, blk in enumerate(self.decoder):
    if i < self.knn_layer:
        q = blk(q, x, new_knn_index, cross_knn_index)  # B M C
    else:
        q = blk(q, x)
blk here is again a Geometry-aware Transformer Block, this time the decoder version with both self-attention over the queries and cross-attention from the queries to the encoder output; here is its forward method.
def forward(self, q, v, self_knn_index=None, cross_knn_index=None):
    # q = q + self.drop_path(self.self_attn(self.norm1(q)))
    norm_q = self.norm1(q)
    q_1 = self.self_attn(norm_q)  # self-attention
    if self_knn_index is not None:
        # gather neighbor features among the queries according to self_knn_index
        knn_f = get_graph_feature(norm_q, self_knn_index)
        knn_f = self.knn_map(knn_f)
        knn_f = knn_f.max(dim=1, keepdim=False)[0]
        q_1 = torch.cat([q_1, knn_f], dim=-1)
        q_1 = self.merge_map(q_1)
    q = q + self.drop_path(q_1)

    norm_q = self.norm_q(q)
    norm_v = self.norm_v(v)
    q_2 = self.attn(norm_q, norm_v)  # cross-attention
    if cross_knn_index is not None:
        # gather neighbor features from the encoder output v according to cross_knn_index
        knn_f = get_graph_feature(norm_v, cross_knn_index, norm_q)
        knn_f = self.knn_map_cross(knn_f)
        knn_f = knn_f.max(dim=1, keepdim=False)[0]
        q_2 = torch.cat([q_2, knn_f], dim=-1)
        q_2 = self.merge_map_cross(q_2)
    q = q + self.drop_path(q_2)

    # q = q + self.drop_path(self.attn(self.norm_q(q), self.norm_v(v)))
    q = q + self.drop_path(self.mlp(self.norm2(q)))
    return q
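To keep the shapes straight with the defaults above: v = x is the encoder output over the 128 proxy centers (B × 128 × 384), while q covers the 96 dynamic queries (B × 96 × 384). Self-attention mixes the 96 queries among themselves; cross-attention lets each query attend to the 128 encoder proxies.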
Back in PoinTr.forward, the decoder output q is expanded into per-proxy reconstruction features. Note that 1027 = 1024 (global feature) + 3 (coarse coordinates), which is exactly why reduce_map was declared with trans_dim + 1027 input features.

global_feature = self.increase_dim(q.transpose(1, 2)).transpose(1, 2)  # B M 1024
global_feature = torch.max(global_feature, dim=1)[0]  # B 1024
rebuild_feature = torch.cat([
    global_feature.unsqueeze(-2).expand(-1, M, -1),
    q,
    coarse_point_cloud], dim=-1)  # B M 1027 + C
rebuild_feature = self.reduce_map(rebuild_feature.reshape(B * M, -1))  # BM C
FoldingNet is then used to recover the detailed local shape centered at each predicted proxy:
relative_xyz = self.foldingnet(rebuild_feature).reshape(B, M, 3, -1)  # B M 3 S
rebuild_points = (relative_xyz + coarse_point_cloud.unsqueeze(-1)).transpose(2, 3).reshape(B, -1, 3)  # B N 3
# cat the input
inp_sparse = fps(xyz, self.num_query)
coarse_point_cloud = torch.cat([coarse_point_cloud, inp_sparse], dim=1).contiguous()
rebuild_points = torch.cat([rebuild_points, xyz], dim=1).contiguous()
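Putting the numbers together: the dense output contains num_pred + N = 6144 + 2048 = 8192 points, and the coarse output contains 96 + 96 = 192 points (the predicted centers plus an FPS subsample of the input). These two clouds form exactly the ret tuple that get_loss compares against the ground truth with the Chamfer-L1 distance.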