PointTransformerLayer
https://github.com/POSTECH-CVLab/point-transformer/blob/master/model/pointtransformer/pointtransformer_seg.py
class PointTransformerLayer(nn.Module):
def __init__(self, in_planes, out_planes, share_planes=8, nsample=16):
super().__init__()
self.mid_planes = mid_planes = out_planes // 1
self.out_planes = out_planes
self.share_planes = share_planes
self.nsample = nsample
self.linear_q = nn.Linear(in_planes, mid_planes)
self.linear_k = nn.Linear(in_planes, mid_planes)
self.linear_v = nn.Linear(in_planes, out_planes)
self.linear_p = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3), nn.ReLU(inplace=True), nn.Linear(3, out_planes))
self.linear_w = nn.Sequential(nn.BatchNorm1d(mid_planes), nn.ReLU(inplace=True),
nn.Linear(mid_planes, mid_planes // share_planes),
nn.BatchNorm1d(mid_planes // share_planes), nn.ReLU(inplace=True),
nn.Linear(out_planes // share_planes, out_planes // share_planes))
self.softmax = nn.Softmax(dim=1)
def forward(self, pxo) -> torch.Tensor:
p, x, o = pxo # (n, 3), (n, c), (b)
x_q, x_k, x_v = self.linear_q(x), self.linear_k(x), self.linear_v(x) # (n, c)
x_k = pointops.queryandgroup(self.nsample, p, p, x_k, None, o, o, use_xyz=True) # (n, nsample, 3+c)
x_v = pointops.queryandgroup(self.nsample, p, p, x_v, None, o, o, use_xyz=False) # (n, nsample, c)
p_r, x_k = x_k[:, :, 0:3], x_k[:, :, 3:]
for i, layer in enumerate(self.linear_p): p_r = layer(p_r.transpose(1, 2).contiguous()).transpose(1, 2).contiguous() if i == 1 else layer(p_r) # (n, nsample, c)
w = x_k - x_q.unsqueeze(1) + p_r.view(p_r.shape[0], p_r.shape[1], self.out_planes // self.mid_planes, self.mid_planes).sum(2) # (n, nsample, c)
for i, layer in enumerate(self.linear_w): w = layer(w.transpose(1, 2).contiguous()).transpose(1, 2).contiguous() if i % 3 == 0 else layer(w)
w = self.softmax(w) # (n, nsample, c)
n, nsample, c = x_v.shape; s = self.share_planes
x = ((x_v + p_r).view(n, nsample, s, c // s) * w.unsqueeze(2)).sum(1).view(n, c)
return x
解读:
邻居特征集合x_k和位置特征p_r,减去单个中心点的特征x_q,生成了权重w
邻居特征集合x_v和位置特征p_r,和w做乘法,再针对邻居维度进行sum运算,得到了最终的特征。