我看的是这位大佬的代码:
T-Net网络解析
class STN3d(nn.Module):
    """T-Net that predicts a per-sample 3x3 input transform for PointNet.

    Args:
        channel: number of input channels D of the point cloud.

    Forward input:  ``x`` of shape (B, D, N).
    Forward output: tensor of shape (B, 3, 3), i.e. the matrix regressed by
    the network plus the 3x3 identity.
    """

    def __init__(self, channel):
        super(STN3d, self).__init__()
        # Shared per-point MLP implemented as kernel-size-1 convolutions:
        # D -> 64 -> 128 -> 1024 channels.
        # NOTE: construction order is kept as-is because it determines the
        # order in which parameters consume the RNG at initialisation.
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        # Fully connected head: 1024 -> 512 -> 256 -> 9 (a flattened 3x3).
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        # Not used by forward() (it calls F.relu); kept so the attribute
        # still exists for any external code that references it.
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Return the (B, 3, 3) transform predicted from ``x`` of shape (B, D, N)."""
        x = F.relu(self.bn1(self.conv1(x)))  # (B, 64, N)
        x = F.relu(self.bn2(self.conv2(x)))  # (B, 128, N)
        x = F.relu(self.bn3(self.conv3(x)))  # (B, 1024, N)
        # Max over the point axis: the result does not depend on point order.
        x = torch.max(x, 2, keepdim=True)[0]  # (B, 1024, 1)
        x = x.view(-1, 1024)  # (B, 1024)
        x = F.relu(self.bn4(self.fc1(x)))  # (B, 512)
        x = F.relu(self.bn5(self.fc2(x)))  # (B, 256)
        x = self.fc3(x)  # (B, 9)
        # Add the flattened identity so that an all-zero fc3 output yields the
        # identity transform. Building it directly on x's device and dtype
        # (instead of numpy + Variable + .cuda()) keeps it on the same device
        # as x and avoids promoting half-precision outputs to float32.
        iden = torch.eye(3, device=x.device, dtype=x.dtype).view(1, 9)
        x = x + iden  # broadcasts over the batch dimension
        return x.view(-1, 3, 3)  # (B, 3, 3)
训练使用的是 ModelNet 数据集,每批训练的 batch size 是 24。
点云数据最开始的格式是 24x1024x3(batch × 点数 × 坐标维度),
我们将其转置成 24x3x1024 的形式,输入 PointNet 网络(即下面的 PointNetEncoder)。
class PointNetEncoder(nn.Module):
    """PointNet encoder: learned input transform, shared MLP, max-pool.

    Args:
        global_feat: when True, forward() returns the max-pooled global
            feature of shape (B, 1024). When False, it returns that feature
            tiled over all N points and concatenated (channel-wise) with the
            per-point features taken right after the first MLP layer.
        feature_transform: when True, an ``STNkd(k=64)`` (defined elsewhere)
            predicts a transform that is batch-multiplied with the 64-channel
            per-point features.
        channel: number of input channels D. The first 3 channels are
            multiplied by the learned 3x3 transform; any further channels are
            appended unchanged.

    forward() returns ``(features, trans, trans_feat)``. ``trans`` is the
    (B, 3, 3) input transform; ``trans_feat`` is the feature transform or
    None when ``feature_transform`` is False.
    """

    def __init__(self, global_feat=True, feature_transform=False, channel=3):
        super(PointNetEncoder, self).__init__()
        # Submodules are created in the original order so random
        # initialisation under a fixed seed is unchanged.
        self.stn = STN3d(channel)  # predicts the (B, 3, 3) input transform
        # Shared per-point MLP as kernel-size-1 convolutions: D -> 64 -> 128 -> 1024.
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        """Encode ``x`` of shape (B, D, N); see the class docstring for outputs."""
        _, D, N = x.size()
        trans = self.stn(x)  # (B, 3, 3)

        # Work point-major, (B, N, D), so bmm applies trans to every point.
        points = x.transpose(2, 1)
        if D > 3:
            xyz, extra = points[:, :, :3], points[:, :, 3:]
            points = torch.cat([torch.bmm(xyz, trans), extra], dim=2)
        else:
            points = torch.bmm(points, trans)

        # Back to channel-major, (B, D, N), for the convolutions.
        h = F.relu(self.bn1(self.conv1(points.transpose(2, 1))))  # (B, 64, N)

        trans_feat = None
        if self.feature_transform:
            # trans_feat presumably has shape (B, 64, 64); STNkd is defined elsewhere.
            trans_feat = self.fstn(h)
            h = torch.bmm(h.transpose(2, 1), trans_feat).transpose(2, 1)

        pointfeat = h  # per-point features kept for the segmentation branch
        h = F.relu(self.bn2(self.conv2(h)))  # (B, 128, N)
        h = self.bn3(self.conv3(h))  # (B, 1024, N); no ReLU before pooling
        global_vec = h.max(dim=2).values  # (B, 1024), max over points

        if self.global_feat:
            return global_vec, trans, trans_feat

        tiled = global_vec.view(-1, 1024, 1).repeat(1, 1, N)  # (B, 1024, N)
        return torch.cat([tiled, pointfeat], 1), trans, trans_feat
trans 是我们通过 T-Net(即上面的 STN3d)生成的一个 24x3x3 的张量,
将原来的数据 24x3x1024 通过 transpose(2,1) 操作转化成 24x1024x3,
然后通过 torch.bmm() 将转置后的数据与 trans 做批量矩阵乘法(batch matrix multiply):(24x1024x3)×(24x3x3) 得到的结果仍是 24x1024x3
torch.bmm(input, mat2, *, out=None) → Tensor
If input is a (b×n×m) tensor, mat2 is a (b×m×p) tensor, out will be a (b×n×p) tensor.
然后通过x.transpose(2,1)继续将数据转化成24x3x1024的形式
之后经过一层卷积,归一化,以及relu操作得到24x64x1024的数据(64通道)
接下来就是另一个feature_transform
通过 STNkd(k=64) 生成特征变换矩阵,过程与上面生成 3×3 矩阵的操作类似,最后得到一个 24x64x64 的矩阵 trans_feat。
然后同样先进行 transpose(2,1) 操作,得到 24x1024x64 的数据,再通过 torch.bmm 与 trans_feat 相乘,
结果仍为 24x1024x64;最后再进行一次 transpose(2,1) 操作,恢复成 24x64x1024 的形式。