import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import utils
class STN(nn.Module):
    """Spatial transformer that predicts a canonicalizing transform for a point set.

    A shared per-point MLP (1x1 convolutions ``dim -> 64 -> 128 -> 1024``) is
    followed by max pooling over points, one pooled vector per scale. A fully
    connected head then regresses either an unconstrained ``dim x dim`` matrix
    or a quaternion that is converted into a 3x3 rotation matrix. In both cases
    an identity offset is added, so a zero network output means "no transform".

    Args:
        num_scales: Number of scales concatenated along the point axis of the
            input; each scale occupies ``num_points`` consecutive columns.
        num_points: Number of points per scale (the max-pooling window).
        dim: Number of input channels per point. It is also the size of the
            predicted matrix when ``quaternion`` is False.
        sym_op: Stored for interface compatibility but not used by
            ``forward``; the symmetric aggregation is always max pooling.
        quaternion: If True, regress a quaternion and return the matrix
            produced by ``utils.batch_quat_to_rotmat``. If False, return a
            ``dim x dim`` matrix.
    """

    def __init__(self, num_scales=1, num_points=500, dim=3, sym_op='max', quaternion=False):
        super(STN, self).__init__()

        self.quaternion = quaternion
        self.dim = dim
        self.sym_op = sym_op  # unused by forward(); pooling is always max
        self.num_scales = num_scales
        self.num_points = num_points

        # Shared per-point MLP implemented as 1x1 convolutions.
        self.conv1 = torch.nn.Conv1d(self.dim, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        # Max over one scale's worth of points.
        self.mp1 = torch.nn.MaxPool1d(num_points)

        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        # Output size: a flattened dim x dim matrix, or 4 quaternion components.
        if not quaternion:
            self.fc3 = nn.Linear(256, self.dim*self.dim)
        else:
            self.fc3 = nn.Linear(256, 4)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        if self.num_scales > 1:
            # Fuse the concatenated per-scale global features back down to 1024.
            self.fc0 = nn.Linear(1024*self.num_scales, 1024)
            self.bn0 = nn.BatchNorm1d(1024)

    def forward(self, x):
        """Predict one transform matrix per batch element.

        Args:
            x: Point features laid out as (batch, dim, num_scales * num_points).
                Scale ``s`` occupies columns ``[s*num_points, (s+1)*num_points)``.

        Returns:
            A (batch, dim, dim) tensor when ``quaternion`` is False. Otherwise,
            the result of ``utils.batch_quat_to_rotmat`` given the quaternions
            and a (batch, 3, 3) output buffer.
        """
        batchsize = x.size(0)

        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        # Symmetric operation over all points: max pooling, one window per scale.
        if self.num_scales == 1:
            x = self.mp1(x)
        else:
            # Pool each scale's block of points separately and stack the results
            # on the channel axis, so channels [s*1024, (s+1)*1024) hold scale s.
            # torch.cat keeps x's device and dtype. The previous
            # torch.cuda.FloatTensor buffer was always float32 and was placed on
            # the current CUDA device rather than on x's device.
            x = torch.cat(
                [self.mp1(x[:, :, s*self.num_points:(s+1)*self.num_points])
                 for s in range(self.num_scales)],
                dim=1)
        x = x.view(-1, 1024*self.num_scales)

        if self.num_scales > 1:
            x = F.relu(self.bn0(self.fc0(x)))
        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        if not self.quaternion:
            # Add the flattened identity (broadcast over the batch) so that a
            # zero network output leaves the point cloud unchanged.
            iden = torch.eye(self.dim, dtype=x.dtype, device=x.device).view(1, self.dim*self.dim)
            x = x + iden
            x = x.view(-1, self.dim, self.dim)
        else:
            # Add the identity quaternion (so the network can output 0 to leave
            # the point cloud unchanged).
            iden = x.new_tensor([1, 0, 0, 0])
            x = x + iden
            # Convert quaternion to rotation matrix. A (batchsize, 3, 3) buffer on
            # x's device/dtype is passed as the output argument; presumably it is
            # filled in place by the helper -- confirm against utils.
            trans = x.new_empty(batchsize, 3, 3)
            x = utils.batch_quat_to_rotmat(x, trans)
        return x
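作用:实现了一个空间变换网络(STN,Spatial Transformer Network),可选用四元数来参数化旋转。网络输入点云,输出一个位姿变换矩阵:`quaternion=False` 时为 dim×dim 的一般矩阵;`quaternion=True` 时先回归四元数,再转换为 3×3 旋转矩阵(两种情况均不含平移)。网络输出会先加上单位矩阵(或单位四元数),因此输出为 0 时表示不做变换。用该矩阵变换原始点云,可将其迁移到一个更有利于网络学习的位姿。网络结构为:先用若干层一维卷积(1×1 卷积)提取逐点特征,再经最大池化得到全局特征(多尺度时每个尺度分别池化后拼接),最后由全连接层回归变换参数。在输出预测结果时,还需对逐点的特征或结果做相应的逆变换以回到原始位姿;本段代码只负责预测变换矩阵,这一步由调用 STN 的上层网络完成。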
来源:https://github.com/mrakotosaon/pointcleannet/blob/master/noise_removal/pcpnet.py