代码地址: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/point_transformer/point_transformer_seg.py
1. PointTransformerLayer:
class PointTransformerLayer(nn.Module):
    """Vector self-attention over each point's k-nearest-neighbor group.

    Each point attends to its ``nsample`` nearest neighbors.  A learned
    positional encoding of the relative coordinates is added both to the
    key-minus-query relation term and to the grouped values, and each
    attention weight is shared across ``share_planes`` channels.
    """

    def __init__(self, in_planes, out_planes, share_planes=8, nsample=16):
        super().__init__()
        # NOTE: `out_planes // 1` leaves mid_planes == out_planes; kept
        # verbatim from the reference implementation.
        self.mid_planes = mid_planes = out_planes // 1
        self.out_planes = out_planes
        self.share_planes = share_planes
        self.nsample = nsample
        self.linear_q = nn.Linear(in_planes, mid_planes)
        self.linear_k = nn.Linear(in_planes, mid_planes)
        self.linear_v = nn.Linear(in_planes, out_planes)
        # Positional-encoding MLP: relative xyz (3 channels) -> out_planes.
        self.linear_p = nn.Sequential(
            nn.Linear(3, 3),
            LayerNorm1d(3),
            nn.ReLU(inplace=True),
            nn.Linear(3, out_planes),
        )
        # Weight-encoding MLP: emits out_planes // share_planes attention
        # weights per neighbor (one weight per channel group).
        self.linear_w = nn.Sequential(
            LayerNorm1d(mid_planes),
            nn.ReLU(inplace=True),
            nn.Linear(mid_planes, out_planes // share_planes),
            LayerNorm1d(out_planes // share_planes),
            nn.ReLU(inplace=True),
            nn.Linear(out_planes // share_planes, out_planes // share_planes),
        )
        # Normalizes attention over the neighbor dimension (dim=1 of
        # the (n, nsample, ...) grouped tensors).
        self.softmax = nn.Softmax(dim=1)

    def forward(self, pxo) -> torch.Tensor:
        p, x, o = pxo  # (n, 3) coords, (n, c) features, (b) batch offsets
        x_q, x_k, x_v = self.linear_q(x), self.linear_k(x), self.linear_v(x)
        # Group keys over each point's kNN neighborhood.  with_xyz=True
        # prepends 3 relative-coordinate channels (split off below).
        x_k, idx = pointops.knn_query_and_group(
            x_k, p, o, new_xyz=p, new_offset=o, nsample=self.nsample, with_xyz=True
        )
        # Group values with the SAME neighbor indices (idx reused), so
        # keys and values refer to identical neighborhoods.
        x_v, _ = pointops.knn_query_and_group(
            x_v,
            p,
            o,
            new_xyz=p,
            new_offset=o,
            idx=idx,
            nsample=self.nsample,
            with_xyz=False,
        )
        # Split grouped keys into relative coordinates and key features.
        p_r, x_k = x_k[:, :, 0:3], x_k[:, :, 3:]
        p_r = self.linear_p(p_r)  # positional encoding, (n, nsample, out_planes)
        # Relation term: key - query + positional encoding.  Because
        # mid_planes == out_planes (see __init__), the reduce collapses a
        # singleton group axis (i == 1), i.e. it is effectively a no-op
        # reshape of p_r.
        r_qk = (
            x_k
            - x_q.unsqueeze(1)
            + einops.reduce(
                p_r, "n ns (i j) -> n ns j", reduction="sum", j=self.mid_planes
            )
        )
        w = self.linear_w(r_qk)  # (n, nsample, out_planes // share_planes)
        w = self.softmax(w)  # attention normalized over the nsample neighbors
        # Weighted sum over neighbors t; each of the i weights is broadcast
        # across the s = share_planes channel groups.
        x = torch.einsum(
            "n t s i, n t i -> n s i",
            einops.rearrange(x_v + p_r, "n ns (s i) -> n ns s i", s=self.share_planes),
            w,
        )
        x = einops.rearrange(x, "n s i -> n (s i)")  # (n, out_planes)
        return x
知识补充:
-
Einops是一个Python库,用于灵活地重新组织张量(tensors)的维度。它提供了简洁而强大的方式来定义和执行各种操作,如转置、重排、合并和拆分张量的维度。pytorch语句也能实现相同的功能。
-
偏移量(Offset)是批量数据中点云之间的分隔符,类似于PyG中的批量(Batch)的概念。批量和偏移量的视觉说明如下:
-
代码和示意图对应图
2. TransitionDown
class TransitionDown(nn.Module):
    """Encoder downsampling block.

    With ``stride == 1`` this is a plain pointwise linear -> bn -> relu.
    Otherwise the point set is subsampled by ``stride`` via farthest point
    sampling, and each sampled point max-pools the linear+bn+relu features
    of its ``nsample`` nearest neighbors (relative xyz concatenated).
    """

    def __init__(self, in_planes, out_planes, stride=1, nsample=16):
        super().__init__()
        self.stride, self.nsample = stride, nsample
        if stride != 1:
            # +3 input channels: grouping prepends relative coordinates.
            self.linear = nn.Linear(3 + in_planes, out_planes, bias=False)
            self.pool = nn.MaxPool1d(nsample)
        else:
            self.linear = nn.Linear(in_planes, out_planes, bias=False)
        self.bn = nn.BatchNorm1d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, pxo):
        p, x, o = pxo  # (n, 3) coords, (n, c) features, (b) batch offsets
        if self.stride == 1:
            # No subsampling: pointwise MLP only.
            return [p, self.relu(self.bn(self.linear(x))), o]
        # Build the subsampled offset vector: every batch entry keeps
        # 1/stride of its points (integer division, cumulative offsets).
        sampled_offsets = []
        total = 0
        prev = 0
        for i in range(o.shape[0]):
            cur = o[i].item()
            total += (cur - prev) // self.stride
            sampled_offsets.append(total)
            prev = cur
        n_o = torch.cuda.IntTensor(sampled_offsets)
        idx = pointops.farthest_point_sampling(p, o, n_o)  # (m,)
        n_p = p[idx.long(), :]  # (m, 3) sampled coordinates
        grouped, _ = pointops.knn_query_and_group(
            x,
            p,
            offset=o,
            new_xyz=n_p,
            new_offset=n_o,
            nsample=self.nsample,
            with_xyz=True,
        )  # (m, nsample, 3 + c)
        feats = self.linear(grouped).transpose(1, 2).contiguous()
        feats = self.relu(self.bn(feats))  # (m, c', nsample)
        feats = self.pool(feats).squeeze(-1)  # max over neighbors -> (m, c')
        return [n_p, feats, n_o]
stride≠1 的情况下先做最远点采样,再对每个采样点的 kNN 邻域特征按 linear→bn→relu→maxpool 聚合;stride=1 时则按照 linear→bn→relu 的顺序逐点前向传播。
3. TransitionUp
class TransitionUp(nn.Module):
    """Decoder upsampling block.

    Two modes, selected at construction time:
      * ``out_planes is None`` (innermost stage): each point's feature is
        concatenated with a linear2-transformed per-cloud mean feature,
        then fused by ``linear1``.
      * otherwise: coarse-level features (``pxo2``) are interpolated onto
        the fine level (``pxo1``) and summed with the transformed
        fine-level features.
    """

    def __init__(self, in_planes, out_planes=None):
        super().__init__()
        if out_planes is None:
            self.linear1 = nn.Sequential(
                nn.Linear(2 * in_planes, in_planes),
                nn.BatchNorm1d(in_planes),
                nn.ReLU(inplace=True),
            )
            self.linear2 = nn.Sequential(
                nn.Linear(in_planes, in_planes),
                nn.ReLU(inplace=True),
            )
        else:
            self.linear1 = nn.Sequential(
                nn.Linear(out_planes, out_planes),
                nn.BatchNorm1d(out_planes),
                nn.ReLU(inplace=True),
            )
            self.linear2 = nn.Sequential(
                nn.Linear(in_planes, out_planes),
                nn.BatchNorm1d(out_planes),
                nn.ReLU(inplace=True),
            )

    def forward(self, pxo1, pxo2=None):
        if pxo2 is not None:
            p1, x1, o1 = pxo1
            p2, x2, o2 = pxo2
            # Interpolate coarse features onto the fine point set, then
            # sum with the transformed fine features.
            return self.linear1(x1) + pointops.interpolation(
                p2, p1, self.linear2(x2), o2, o1
            )
        # Single-input mode: append a transformed global mean per cloud.
        _, x, o = pxo1  # (n, 3), (n, c), (b) batch offsets
        chunks = []
        start = 0
        for end in o:  # iterate per-cloud offset boundaries
            count = end - start
            part = x[start:end, :]
            mean = part.sum(0, True) / count
            part = torch.cat((part, self.linear2(mean).repeat(count, 1)), 1)
            chunks.append(part)
            start = end
        return self.linear1(torch.cat(chunks, 0))
分两种情况:
当只有 pxo1 的时候,对每个点云样本计算特征的全局均值,经 linear2 变换后与原始特征拼接,再经 linear1 融合。
当同时有 pxo1 和 pxo2 的时候,把 pxo2 的特征插值到 pxo1 所在的更大尺寸上,与 pxo1 经 linear1 变换后的特征相加(不是拼接)。
4. PointTransformerSeg
self.in_planes, planes = in_channels, [32, 64, 128, 256, 512]
fpn_planes, fpnhead_planes, share_planes = 128, 64, 8
stride, nsample = [1, 4, 4, 4, 4], [8, 16, 16, 16, 16]
网络结构参数设置如上。