[ICCV 2021 Oral] PoinTr Code Walkthrough


This is entirely my personal understanding.

Overall Code

First, the full model class is posted, taken from the PoinTr source code.
Then the code is explained by following the forward pass.

class PoinTr(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        # num_pred: 6144, num_query: 96, knn_layer: 1, trans_dim: 384
        self.trans_dim = 384
        self.knn_layer = 1
        self.num_pred = 6144
        self.num_query = 96

        self.fold_step = int(pow(self.num_pred//self.num_query, 0.5) + 0.5)
        self.base_model = PCTransformer(in_chans = 3, embed_dim = self.trans_dim, depth = [6, 8], drop_rate = 0., num_query = self.num_query, knn_layer = self.knn_layer)
        
        self.foldingnet = Fold(self.trans_dim, step = self.fold_step, hidden_dim = 256)  # rebuild a cluster point

        self.increase_dim = nn.Sequential(
            nn.Conv1d(self.trans_dim, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv1d(1024, 1024, 1)
        )
        self.reduce_map = nn.Linear(self.trans_dim + 1027, self.trans_dim)
        self.build_loss_func()

    def build_loss_func(self):
        self.loss_func = ChamferDistanceL1()

    def get_loss(self, ret, gt):
        loss_coarse = self.loss_func(ret[0], gt)
        loss_fine = self.loss_func(ret[1], gt)
        return loss_coarse, loss_fine

    def forward(self, xyz):
        q, coarse_point_cloud = self.base_model(xyz) # B M C and B M 3
    
        B, M ,C = q.shape

        global_feature = self.increase_dim(q.transpose(1,2)).transpose(1,2) # B M 1024
        global_feature = torch.max(global_feature, dim=1)[0] # B 1024

        rebuild_feature = torch.cat([
            global_feature.unsqueeze(-2).expand(-1, M, -1),
            q,
            coarse_point_cloud], dim=-1)  # B M 1027 + C

        rebuild_feature = self.reduce_map(rebuild_feature.reshape(B*M, -1)) # BM C
        # # NOTE: try to rebuild pc
        # coarse_point_cloud = self.refine_coarse(rebuild_feature).reshape(B, M, 3)

        # NOTE: foldingNet
        relative_xyz = self.foldingnet(rebuild_feature).reshape(B, M, 3, -1)    # B M 3 S
        rebuild_points = (relative_xyz + coarse_point_cloud.unsqueeze(-1)).transpose(2,3).reshape(B, -1, 3)  # B N 3

        # NOTE: fc
        # relative_xyz = self.refine(rebuild_feature)  # BM 3S
        # rebuild_points = (relative_xyz.reshape(B,M,3,-1) + coarse_point_cloud.unsqueeze(-1)).transpose(2,3).reshape(B, -1, 3)

        # cat the input
        inp_sparse = fps(xyz, self.num_query)
        coarse_point_cloud = torch.cat([coarse_point_cloud, inp_sparse], dim=1).contiguous()
        rebuild_points = torch.cat([rebuild_points, xyz],dim=1).contiguous()

        ret = (coarse_point_cloud, rebuild_points)
        return ret
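A quick sanity check on the fold_step arithmetic in `__init__`: each of the num_query = 96 centers is rebuilt into an 8 × 8 grid of points, so 96 × 64 = 6144 = num_pred points come out of the FoldingNet before the input is concatenated back in.

```python
# Verify the fold_step arithmetic from __init__ (values taken from the source).
num_pred, num_query = 6144, 96
points_per_center = num_pred // num_query           # 64 points rebuilt around each center
fold_step = int(pow(points_per_center, 0.5) + 0.5)  # round(sqrt(64)) = 8
print(fold_step)                           # 8
print(fold_step * fold_step * num_query)   # 8 * 8 * 96 = 6144 == num_pred
```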

Point Proxy

Given a point cloud $\{q_1, q_2, \dots, q_N\}$ ($N = 2048$ in the paper), a Point Proxy represents a local region of the point cloud. It is a feature vector that captures the local structure around a point and can be computed as
$$F_i = F'_i + \varphi(q_i)$$
where $F'_i$ is the feature of point $q_i$ extracted with the DGCNN model, and $\varphi$ is another MLP used to capture the local information of the Point Proxy.

DGCNN_Grouper

As described in the paper, the first step is to extract features around the point centers from the input point cloud, using a lightweight DGCNN with hierarchical downsampling.

		# input point cloud [B, N, 3]; here N = 2048
		coor, f = self.grouper(inpc.transpose(1,2).contiguous()) 

Looking at the grouper in detail — it is simply DGCNN_Grouper:

        # first apply input_trans (conv1d)
        coor = x
        f = self.input_trans(x)  # conv1d(3, 8)

        # for each point in coor [B, 3, 2048], find its k nearest neighbours in
        # coor itself and build the corresponding graph feature f
        f = self.get_graph_feature(coor, f, coor, f)
        f = self.layer1(f)
        f = f.max(dim=-1, keepdim=False)[0]

        # FPS downsampling to 512 points
        coor_q, f_q = self.fps_downsample(coor, f, 512)
        # for each point in coor_q [B, 3, 512], find its k nearest neighbours in
        # coor [B, 3, 2048] and build the corresponding graph feature f
        f = self.get_graph_feature(coor_q, f_q, coor, f)
        f = self.layer2(f)
        f = f.max(dim=-1, keepdim=False)[0]

        coor = coor_q
        # for each point in coor [B, 3, 512], find its k nearest neighbours in
        # coor itself and build the corresponding graph feature f
        f = self.get_graph_feature(coor, f, coor, f)
        f = self.layer3(f)
        f = f.max(dim=-1, keepdim=False)[0]

        # FPS downsampling to 128 points
        coor_q, f_q = self.fps_downsample(coor, f, 128)
        # for each point in coor_q [B, 3, 128], find its k nearest neighbours in
        # coor [B, 3, 512] and build the corresponding graph feature f
        f = self.get_graph_feature(coor_q, f_q, coor, f)
        f = self.layer4(f)
        f = f.max(dim=-1, keepdim=False)[0]
        coor = coor_q
        return coor, f

Now look at get_graph_feature:
For each point in coor_q, find its K nearest neighbours in coor_k (here K = 16) and record the indices idx.
Using idx, gather the matching features from x_k, denoted feature.
Finally, feature = torch.cat([feature - x_q, x_q], dim=1).

def get_graph_feature(coor_q, x_q, coor_k, x_k):
    # coor: bs, 3, np; x: bs, c, np
    k = 16
    batch_size = x_k.size(0)
    num_points_k = x_k.size(2)
    num_points_q = x_q.size(2)
    with torch.no_grad():
        _, idx = knn(coor_k, coor_q)  # bs k np
        assert idx.shape[1] == k
        idx_base = torch.arange(0, batch_size, device=x_q.device).view(-1, 1, 1) * num_points_k
        idx = idx + idx_base
        idx = idx.view(-1)
    num_dims = x_k.size(1)
    x_k = x_k.transpose(2, 1).contiguous()
    feature = x_k.view(batch_size * num_points_k, -1)[idx, :]
    feature = feature.view(batch_size, k, num_points_q, num_dims).permute(0, 3, 2, 1).contiguous()
    x_q = x_q.view(batch_size, num_dims, num_points_q, 1).expand(-1, -1, -1, k)
    feature = torch.cat((feature - x_q, x_q), dim=1)
    return feature

The overall structure: the leftmost column is the input and the rightmost column is the output.

MLP

This part is inspired by the position embedding in the transformer; the operation encodes the global position of each Point Proxy.

pos =  self.pos_embed(coor).transpose(1,2)

Then the feature f extracted by the DGCNN goes through an input_proj operation:

x = self.input_proj(f).transpose(1,2)

Finally, adding pos and x gives the Point Proxies.
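The two projections and the sum can be sketched as follows. The layer sizes here are illustrative assumptions (the real pos_embed and input_proj are defined in PCTransformer's `__init__`); only the tensor shapes and the final addition follow the source.

```python
import torch
import torch.nn as nn

# Illustrative layer definitions (assumptions, not the source's exact layers).
embed_dim = 384
pos_embed = nn.Sequential(nn.Conv1d(3, 128, 1), nn.GELU(), nn.Conv1d(128, embed_dim, 1))
input_proj = nn.Conv1d(128, embed_dim, 1)

coor = torch.rand(2, 3, 128)   # B, 3, 128 center coordinates from DGCNN_Grouper
f = torch.rand(2, 128, 128)    # B, C, 128 per-center features

pos = pos_embed(coor).transpose(1, 2)  # B, 128, embed_dim: global position encoding
x = input_proj(f).transpose(1, 2)      # B, 128, embed_dim: projected semantic feature
proxies = x + pos                      # Point Proxies fed into the encoder
print(proxies.shape)  # torch.Size([2, 128, 384])
```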

Geometry-aware Transformer Block

To help the transformer better exploit the inductive bias of the 3D geometric structure of point clouds, the authors design a geometry-aware block to model geometric relations. It is a plug-and-play module that can be combined with the attention in any transformer architecture.

A kNN model is used to capture geometric relations in the point cloud. Given the query coordinates $p_Q$, the features of the neighbouring keys are gathered according to the key coordinates $p_K$. Then, following DGCNN, the local geometric structure is learned by aggregating features through a linear layer followed by max pooling. The geometric and semantic features are then concatenated and mapped back to the original dimension to form the output.

Point Cloud Generation

The paper generates two point clouds step by step: coarse_point_cloud, and the final point cloud.

Generating coarse_point_cloud

First a coarse point cloud is generated; the code is as follows:

        for i, blk in enumerate(self.encoder):
            if i < self.knn_layer:
                x = blk(x + pos, knn_index)   # B N C
            else:
                x = blk(x + pos)

As mentioned above, x + pos is the Point Proxy.
self.knn_layer is the number of layers that apply kNN; the authors set it to 1, i.e. kNN is used in one layer only.

Now let's walk through blk, which implements the Geometry-aware Transformer Block; focus on its forward:

    def forward(self, x, knn_index = None):
        # x = x + self.drop_path(self.attn(self.norm1(x)))
        norm_x = self.norm1(x)
        x_1 = self.attn(norm_x)
        if knn_index is not None:
            knn_f = get_graph_feature(norm_x, knn_index)
            knn_f = self.knn_map(knn_f)
            knn_f = knn_f.max(dim=1, keepdim=False)[0]
            x_1 = torch.cat([x_1, knn_f], dim=-1)
            x_1 = self.merge_map(x_1)
        
        x = x + self.drop_path(x_1)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

From the code we can see:

  1. When knn_index is present, the input goes through two paths:
    attention
    get_graph_feature, which serves the same purpose as in DGCNN: gathering neighbourhood features.

The outputs of the two paths are concatenated, fused by merge_map, added back to the input as a residual, and then passed through an MLP.

def get_graph_feature(x, knn_index, x_q=None):
    # x: bs, np, c; knn_index: bs*k*np
    k = 8
    batch_size, num_points, num_dims = x.size()
    num_query = x_q.size(1) if x_q is not None else num_points
    feature = x.view(batch_size * num_points, num_dims)[knn_index, :]
    feature = feature.view(batch_size, k, num_query, num_dims)
    x = x_q if x_q is not None else x
    x = x.view(batch_size, 1, num_query, num_dims).expand(-1, k, -1, -1)
    feature = torch.cat((feature - x, x), dim=-1)
    return feature  # b k np c


2. When knn_index is absent, the input simply passes through attention and the MLP in sequence.

Query Generator

The queries serve as the initial state of the predicted proxies. To ensure the queries properly reflect the complete point cloud, the authors propose a Query Generator that dynamically generates the predicted proxies from the encoder output:

global_feature = self.increase_dim(x.transpose(1,2)) # B 1024 N 
global_feature = torch.max(global_feature, dim=-1)[0] # B 1024
coarse_point_cloud = self.coarse_pred(global_feature).reshape(bs, -1, 3)  #  B M C(3)

Finally, the encoder's global feature and the coordinates are concatenated, and an MLP generates the predicted proxies:

query_feature = torch.cat([
            global_feature.unsqueeze(1).expand(-1, self.num_query, -1), 
            coarse_point_cloud], dim=-1) # B M C+3 
q = self.mlp_query(query_feature.transpose(1,2)).transpose(1,2) # B M C 
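Putting the two snippets together, here is a minimal runnable sketch of the query generator. The layer definitions below are simplified assumptions; in the source they are built in PCTransformer's `__init__`.

```python
import torch
import torch.nn as nn

# Illustrative layers (assumptions); only the shapes follow the source.
trans_dim, num_query = 384, 96
increase_dim = nn.Conv1d(trans_dim, 1024, 1)
coarse_pred = nn.Linear(1024, 3 * num_query)
mlp_query = nn.Conv1d(1024 + 3, trans_dim, 1)

bs = 2
x = torch.rand(bs, 128, trans_dim)                       # encoder output: B N C

global_feature = increase_dim(x.transpose(1, 2))         # B 1024 N
global_feature = torch.max(global_feature, dim=-1)[0]    # B 1024
coarse = coarse_pred(global_feature).reshape(bs, -1, 3)  # B M 3 coarse centers

query_feature = torch.cat([
    global_feature.unsqueeze(1).expand(-1, num_query, -1),
    coarse], dim=-1)                                     # B M 1024+3
q = mlp_query(query_feature.transpose(1, 2)).transpose(1, 2)  # B M C
print(q.shape)  # torch.Size([2, 96, 384])
```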

Multi-Scale Point Cloud Generation

The goal of the encoder-decoder network is to predict the missing part of an incomplete point cloud. However, from the transformer decoder we only obtain predictions of the missing proxies. A multi-scale point cloud generation framework is therefore proposed to recover the missing point cloud at full resolution.
The decoder runs as follows:

for i, blk in enumerate(self.decoder):
    if i < self.knn_layer:
        q = blk(q, x, new_knn_index, cross_knn_index)   # B M C
    else:
        q = blk(q, x)

Here blk is again a Geometry-aware Transformer Block (the decoder variant); look at its forward:

def forward(self, q, v, self_knn_index=None, cross_knn_index=None):
        # q = q + self.drop_path(self.self_attn(self.norm1(q)))
        norm_q = self.norm1(q)
        q_1 = self.self_attn(norm_q) # self-attention

        if self_knn_index is not None:
            # gather neighbourhood features according to self_knn_index
            knn_f = get_graph_feature(norm_q, self_knn_index)
            knn_f = self.knn_map(knn_f)
            knn_f = knn_f.max(dim=1, keepdim=False)[0]
            q_1 = torch.cat([q_1, knn_f], dim=-1)
            q_1 = self.merge_map(q_1)

        q = q + self.drop_path(q_1)
        norm_q = self.norm_q(q)
        norm_v = self.norm_v(v)
        q_2 = self.attn(norm_q, norm_v) # cross-attention

        if cross_knn_index is not None:
            # gather neighbourhood features according to cross_knn_index
            knn_f = get_graph_feature(norm_v, cross_knn_index, norm_q)
            knn_f = self.knn_map_cross(knn_f)
            knn_f = knn_f.max(dim=1, keepdim=False)[0]
            q_2 = torch.cat([q_2, knn_f], dim=-1)
            q_2 = self.merge_map_cross(q_2)
        q = q + self.drop_path(q_2)
        # q = q + self.drop_path(self.attn(self.norm_q(q), self.norm_v(v)))
        q = q + self.drop_path(self.mlp(self.norm2(q)))
        return q

global_feature = self.increase_dim(q.transpose(1,2)).transpose(1,2) # B M 1024
global_feature = torch.max(global_feature, dim=1)[0] # B 1024
rebuild_feature = torch.cat([
            global_feature.unsqueeze(-2).expand(-1, M, -1),
            q,
            coarse_point_cloud], dim=-1)  # B M 1027 + C
rebuild_feature = self.reduce_map(rebuild_feature.reshape(B*M, -1)) # BM C

FoldingNet is then used to recover the detailed local shape centred at each predicted proxy:

relative_xyz = self.foldingnet(rebuild_feature).reshape(B, M, 3, -1)    # B M 3 S
rebuild_points = (relative_xyz + coarse_point_cloud.unsqueeze(-1)).transpose(2,3).reshape(B, -1, 3)  # B N 3
# cat the input
inp_sparse = fps(xyz, self.num_query)
coarse_point_cloud = torch.cat([coarse_point_cloud, inp_sparse], dim=1).contiguous()
rebuild_points = torch.cat([rebuild_points, xyz],dim=1).contiguous()
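A quick shape check of the rebuild step: adding the S = 8 × 8 = 64 offsets from the FoldingNet to each of the M = 96 centers and flattening yields B × 6144 × 3, matching num_pred.

```python
import torch

# Shape check for the FoldingNet rebuild step (random tensors stand in for
# the network outputs; only the reshape logic is being verified).
B, M, S = 2, 96, 64
relative_xyz = torch.rand(B, M, 3, S)       # per-center offsets from foldingnet
coarse_point_cloud = torch.rand(B, M, 3)    # predicted centers

rebuild_points = (relative_xyz + coarse_point_cloud.unsqueeze(-1)) \
    .transpose(2, 3).reshape(B, -1, 3)      # B, M*S, 3
print(rebuild_points.shape)  # torch.Size([2, 6144, 3])
```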