代码地址:https://github.com/POSTECH-CVLab/point-transformer/blob/master/model/pointtransformer/pointtransformer_seg.py
class TransitionDown(nn.Module):
    """Feature transition that optionally downsamples the point set.

    With ``stride == 1`` the block is a per-point
    ``Linear -> BatchNorm1d -> ReLU`` and the coordinates and offsets are
    returned untouched.  With ``stride != 1`` every batch segment is reduced
    to ``len // stride`` points selected by ``pointops.furthestsampling``;
    ``pointops.queryandgroup`` then gathers ``nsample`` neighbours for each
    kept point and their features are projected, normalised, activated and
    max-pooled over the neighbourhood.

    Args:
        in_planes: Channel count of the incoming features.
        out_planes: Channel count of the produced features.
        stride: Downsampling factor; ``1`` disables sampling.
        nsample: Number of neighbours grouped around each kept point.
    """

    def __init__(self, in_planes, out_planes, stride=1, nsample=16):
        super().__init__()
        self.stride, self.nsample = stride, nsample
        if stride != 1:
            # The grouped features get 3 extra channels because
            # queryandgroup is called with use_xyz=True (presumably the
            # neighbour coordinates -- confirm against pointops).
            self.linear = nn.Linear(3 + in_planes, out_planes, bias=False)
            self.pool = nn.MaxPool1d(nsample)
        else:
            self.linear = nn.Linear(in_planes, out_planes, bias=False)
        self.bn = nn.BatchNorm1d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _downsampled_offsets(o, stride):
        """Return cumulative offsets after keeping ``len // stride`` points per segment."""
        kept_total, previous_end, new_offsets = 0, 0, []
        # A single host transfer for all offsets instead of one ``.item()``
        # call (and device sync) per read.
        for end in o.tolist():
            kept_total += (end - previous_end) // stride
            new_offsets.append(kept_total)
            previous_end = end
        # Allocate next to ``o`` instead of on whichever CUDA device happens
        # to be current, which is what the legacy torch.cuda.IntTensor does.
        return torch.tensor(new_offsets, dtype=torch.int32, device=o.device)

    def forward(self, pxo):
        """Apply the transition.

        Args:
            pxo: ``(p, x, o)`` -- coordinates ``(n, 3)``, features ``(n, c)``
                and cumulative per-sample end offsets ``(b)``.

        Returns:
            ``[p, x, o]`` with the transformed features; when
            ``stride != 1`` the coordinates and offsets describe the
            downsampled set of ``m`` points.
        """
        p, x, o = pxo  # (n, 3), (n, c), (b)
        if self.stride != 1:
            n_o = self._downsampled_offsets(o, self.stride)
            idx = pointops.furthestsampling(p, o, n_o)  # (m)
            n_p = p[idx.long(), :]  # (m, 3)
            # Assumed (m, nsample, 3+c): the Linear below acts on the last
            # dim and the transpose moves channels to dim 1 for BatchNorm1d.
            x = pointops.queryandgroup(self.nsample, p, n_p, x, None, o, n_o, use_xyz=True)
            x = self.relu(self.bn(self.linear(x).transpose(1, 2).contiguous()))  # (m, c, nsample)
            x = self.pool(x).squeeze(-1)  # (m, c)
            p, o = n_p, n_o
        else:
            x = self.relu(self.bn(self.linear(x)))  # (n, c)
        return [p, x, o]
相当于最远点采样(FPS)+ 邻域分组 + maxpool,中间有线性变换、BatchNorm 和非线性激活层(ReLU);当 stride 为 1 时不做采样,只做逐点的线性变换、BatchNorm 和 ReLU。
class TransitionUp(nn.Module):
    """Feature transition used on the upsampling path.

    Two configurations exist and the way ``forward`` is called must match:

    * ``out_planes is None``: ``forward(pxo1)`` concatenates to every point
      a projected mean feature of its batch segment and fuses the result
      with ``linear1`` (which takes ``2 * in_planes`` channels).
    * ``out_planes`` given: ``forward(pxo1, pxo2)`` adds ``linear1(x1)`` to
      the output of ``pointops.interpolation`` applied to ``linear2(x2)``
      (presumably carrying the ``pxo2`` features onto the ``pxo1`` points --
      confirm against pointops).

    Args:
        in_planes: Channel count of the features fed to ``linear2``.
        out_planes: Output channel count, or ``None`` for the
            single-input configuration.
    """

    def __init__(self, in_planes, out_planes=None):
        super().__init__()
        if out_planes is None:
            self.linear1 = nn.Sequential(
                nn.Linear(2 * in_planes, in_planes),
                nn.BatchNorm1d(in_planes),
                nn.ReLU(inplace=True),
            )
            self.linear2 = nn.Sequential(
                nn.Linear(in_planes, in_planes),
                nn.ReLU(inplace=True),
            )
        else:
            self.linear1 = nn.Sequential(
                nn.Linear(out_planes, out_planes),
                nn.BatchNorm1d(out_planes),
                nn.ReLU(inplace=True),
            )
            self.linear2 = nn.Sequential(
                nn.Linear(in_planes, out_planes),
                nn.BatchNorm1d(out_planes),
                nn.ReLU(inplace=True),
            )

    def forward(self, pxo1, pxo2=None):
        """Compute the transitioned features.

        Args:
            pxo1: ``(p, x, o)`` -- coordinates ``(n, 3)``, features
                ``(n, c)`` and cumulative per-sample end offsets ``(b)``.
            pxo2: Optional second ``(p, x, o)`` triple whose features are
                passed through ``linear2`` and ``pointops.interpolation``.

        Returns:
            The feature tensor only; it has one row per point of ``pxo1``
            in the single-input case.
        """
        if pxo2 is None:
            _, x, o = pxo1  # (n, 3), (n, c), (b)
            fused_segments = []
            start = 0
            # Read the offsets once as Python ints so slicing and repeat()
            # do not convert a 0-dim tensor (with a host sync) per segment.
            for end in o.tolist():
                count = end - start
                segment = x[start:end, :]
                # Mean feature of the segment, projected, then tiled so
                # every point of the segment receives a copy.
                pooled = self.linear2(segment.sum(0, True) / count)
                fused_segments.append(torch.cat((segment, pooled.repeat(count, 1)), 1))
                start = end
            x = torch.cat(fused_segments, 0)
            x = self.linear1(x)
        else:
            p1, x1, o1 = pxo1
            p2, x2, o2 = pxo2
            x = self.linear1(x1) + pointops.interpolation(p2, p1, self.linear2(x2), o2, o1)
        return x
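上采样和求和:有两个输入时,先把低分辨率特征经线性层后插值(上采样)到高分辨率点上,再与经线性层变换的高分辨率特征逐元素相加;只有一个输入时,则把每个样本的全局平均特征经线性层后拼接到各点特征上,再做一次线性变换。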