pointnet++ / models / pointnet_utils / index_point(point, idx)
index_point() 是pointnet++中的一个方法,其源码如下:
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C] [batch,number,condition]
idx: sample index data, [B, S] [batch,sample]
Return:
new_points:, indexed points data, [B, S, C] [batch,sample,condition]
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
new_points = points[batch_indices, idx, :]
return new_points
作用:
结合输入的总体点云位置数据和筛选出的部分样本点云ID,返回筛选出的部分样本点云位置数据。
也就是说前面的过程中我们获取了目标点云在矩阵中的ID,通过这个函数将ID值转换为各点的(x,y,z)坐标值。
解释:
B代表batch size;
本文中B = 8 ;
B = points.shape[0]
获取points矩阵第一列的维度;
实际上是得到了我们导入的点云数据的第一个维度,也就是batch_size。本例中B = 8
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
view_shape先获得了idx的形状,再将维度减1又限定大小为1,赋值给view_shape后面的列;
这一步是为了还原idx矩阵第一维的大小,并且保持矩阵维度不变,除第一列外个维度大小全部定为1。
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
repeat_shape先获得了idx的形状,再将第一个维度大小限定为1,其他维度大小保持原样不变;
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
先将B排序,然后view成view_shape的形状,这里的形状为[8,1,1];
然后再按repeat_shape的形状对每个维度进行复制,一共复制8次;
最终得到一个和idx大小相同,内容为各个batch序号的张量,例如{tensor(8,512,16)}:就是八组[512,16]的矩阵,从一到八内容分别从0到7;
其目的是建立一个和index格式相同的模板,在下一步根据位置取出points张量中的坐标值。
new_points = points[batch_indices, idx, :]
batch_indices是上一步中构建的各batch的标号;
idx中保存了本次的采样点ID;
:表示要获取对应标号在new_points矩阵中的所有数据,也就是(x,y,z)坐标。
所以本例中最终返回的矩阵就是(8,512,16,3),即8个batch、512个中心采样点、中心点周围16个临近点、每点3坐标值的四维矩阵。