PointTransformerBlock
class PointTransformerBlock(nn.Module):
    """Residual block built around a ``PointTransformerLayer``.

    The main branch is ``linear1 -> bn1 -> ReLU -> transformer2 -> bn2 ->
    ReLU -> linear3 -> bn3``. The block input features are then added to the
    main-branch output in place and a final ReLU is applied.

    Args:
        in_planes: Width of the incoming feature tensor.
        planes: Width used inside the block; the output width is
            ``planes * expansion``.
        share_planes: Forwarded unchanged to ``PointTransformerLayer``.
        nsample: Forwarded unchanged to ``PointTransformerLayer``.

    Note:
        The skip connection adds the raw input features to the output of
        ``bn3``, so callers are expected to use ``in_planes`` equal to
        ``planes * expansion`` (no projection is applied on the skip path).
    """

    # Multiplier applied to ``planes`` to obtain the block's output width.
    expansion = 1

    def __init__(self, in_planes, planes, share_planes=8, nsample=16):
        super().__init__()
        out_planes = planes * self.expansion

        # Submodules are registered in the same order as the data flows.
        self.linear1 = nn.Linear(in_planes, planes, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)

        self.transformer2 = PointTransformerLayer(planes, planes, share_planes, nsample)
        self.bn2 = nn.BatchNorm1d(planes)

        self.linear3 = nn.Linear(planes, out_planes, bias=False)
        self.bn3 = nn.BatchNorm1d(out_planes)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, pxo):
        """Run the block on a ``[p, x, o]`` triple.

        Args:
            pxo: Sequence ``(p, x, o)``; per the original annotation these
                are shaped ``(n, 3)``, ``(n, c)`` and ``(b)``.

        Returns:
            A list ``[p, x_out, o]`` where ``p`` and ``o`` are passed through
            untouched and ``x_out`` holds the updated features.
        """
        points, feats, offsets = pxo  # (n, 3), (n, c), (b)
        shortcut = feats

        # Point-wise linear projection.
        out = self.linear1(feats)
        out = self.bn1(out)
        out = self.relu(out)

        # The transformer layer takes the full triple and returns features only.
        out = self.transformer2([points, out, offsets])
        out = self.bn2(out)
        out = self.relu(out)

        # Second point-wise projection; no activation before the skip add.
        out = self.linear3(out)
        out = self.bn3(out)

        # Residual connection followed by the final activation.
        out += shortcut
        out = self.relu(out)
        return [points, out, offsets]
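该模块先做线性变换(linear1 + BN + ReLU),再经过 PointTransformerLayer(transformer2 + BN + ReLU),最后再接一个线性层(linear3 + BN);另有一条 identity(残差)通路,把输入特征与主路输出相加后再做一次 ReLU。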
PointTransformerSeg
class PointTransformerSeg(nn.Module):
    """Five-stage encoder/decoder network with a per-point output head.

    The encoder stages ``enc1`` .. ``enc5`` each start with a
    ``TransitionDown`` module, and the decoder stages ``dec5`` .. ``dec1``
    each start with a ``TransitionUp`` module. The final ``cls`` head maps
    the ``dec1`` features to ``k`` output channels.

    Args:
        block: Block class. It must expose an ``expansion`` attribute and is
            constructed as ``block(in_planes, in_planes, share_planes,
            nsample=nsample)``.
        blocks: Indexable with at least five integers. Encoder stage ``i``
            holds one ``TransitionDown`` followed by ``blocks[i] - 1``
            instances of ``block``.
        c: Input feature width handed to the first ``TransitionDown``. When
            ``c == 3`` the point coordinates alone are used as features,
            otherwise the coordinates are concatenated with the given
            features.
        k: Output width of the final linear layer (presumably the number of
            classes).
    """

    def __init__(self, block, blocks, c=6, k=13):
        super().__init__()
        self.c = c
        # Running channel count, updated by _make_enc / _make_dec as stages
        # are created. Construction order therefore matters.
        self.in_planes = c

        planes = [32, 64, 128, 256, 512]
        share_planes = 8
        stride = [1, 4, 4, 4, 4]
        nsample = [8, 16, 16, 16, 16]
        num_stages = len(planes)

        # Encoder, shallow to deep: enc1..enc5 (N/1, N/4, N/16, N/64, N/256).
        for level in range(num_stages):
            stage = self._make_enc(
                block,
                planes[level],
                blocks[level],
                share_planes,
                stride=stride[level],
                nsample=nsample[level],
            )
            setattr(self, f"enc{level + 1}", stage)

        # Decoder, deep to shallow: dec5 transforms p5 on its own (head);
        # dec4..dec1 each fuse the next-coarser level into the current one.
        for level in reversed(range(num_stages)):
            stage = self._make_dec(
                block,
                planes[level],
                2,
                share_planes,
                nsample=nsample[level],
                is_head=(level == num_stages - 1),
            )
            setattr(self, f"dec{level + 1}", stage)

        head_width = planes[0]
        self.cls = nn.Sequential(
            nn.Linear(head_width, head_width),
            nn.BatchNorm1d(head_width),
            nn.ReLU(inplace=True),
            nn.Linear(head_width, k),
        )

    def _make_enc(self, block, planes, blocks, share_planes=8, stride=1, nsample=16):
        """Build one encoder stage and advance ``self.in_planes``.

        The stage is a ``TransitionDown`` from the current ``self.in_planes``
        to ``planes * block.expansion`` followed by ``blocks - 1`` instances
        of ``block`` at that width.
        """
        out_planes = planes * block.expansion
        down = TransitionDown(self.in_planes, out_planes, stride, nsample)
        self.in_planes = out_planes
        body = [
            block(self.in_planes, self.in_planes, share_planes, nsample=nsample)
            for _ in range(blocks - 1)
        ]
        return nn.Sequential(down, *body)

    def _make_dec(self, block, planes, blocks, share_planes=8, nsample=16, is_head=False):
        """Build one decoder stage and advance ``self.in_planes``.

        The stage is a ``TransitionUp`` followed by ``blocks - 1`` instances
        of ``block``. For the head stage the ``TransitionUp`` receives
        ``None`` as its output width instead of ``planes * block.expansion``.
        """
        out_planes = planes * block.expansion
        up = TransitionUp(self.in_planes, None if is_head else out_planes)
        self.in_planes = out_planes
        body = [
            block(self.in_planes, self.in_planes, share_planes, nsample=nsample)
            for _ in range(blocks - 1)
        ]
        return nn.Sequential(up, *body)

    def forward(self, pxo):
        """Run the full network on a ``[p, x, o]`` triple.

        Args:
            pxo: Sequence ``(p0, x0, o0)``; per the original annotation these
                are shaped ``(n, 3)``, ``(n, c)`` and ``(b)``.

        Returns:
            The output of ``self.cls`` applied to the ``dec1`` features.
        """
        p0, x0, o0 = pxo  # (n, 3), (n, c), (b)
        if self.c == 3:
            x0 = p0
        else:
            x0 = torch.cat((p0, x0), 1)

        # Encoder pass: keep every level's (p, x, o) for the decoder.
        pyramid = []
        p, x, o = p0, x0, o0
        for enc in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            p, x, o = enc([p, x, o])
            pyramid.append((p, x, o))

        # Head: the first decoder module sees only the deepest level; the
        # remaining modules of the stage then run on the refreshed features.
        p, x, o = pyramid.pop()
        head_feats = self.dec5[0]([p, x, o])
        x = self.dec5[1:]([p, head_feats, o])[1]
        coarser = [p, x, o]

        # Remaining decoder stages: fuse the already-decoded coarser level
        # into the current encoder level, walking back up to level 1.
        for dec in (self.dec4, self.dec3, self.dec2, self.dec1):
            p, x, o = pyramid.pop()
            fused = dec[0]([p, x, o], coarser)
            x = dec[1:]([p, fused, o])[1]
            coarser = [p, x, o]

        return self.cls(x)