替换一个batch中每个样本的mask位置的元素,用gather以及where函数
问题描述
给定一个batch中X,Y两个点云数组,形状分别为(B, N1, 3)和(B, N2, 3) ,计算两个点云的点对点距离,获得A数组中距离小于阈值的点的mask,并将这些点替换为B数组中每个样本所对应的第一个点。
简单点说,就是把A的点云中离B距离近的点都用B中的点替换。
一、获取Y中每个样本的第一个点
X = torch.randn(2,
Y = torch.
dist = torch.cdist(X, Y) * 1000
threshold = 20
min_dist, min_idx = torch.min(dist, dim=2)
mask = min_dist < threshold
batch_indices = torch.zeros(B, dtype=torch.long)
first_points = torch.gather(Y, 1, batch_indices.view(B, 1, 1).expand(-1, 1, 3))
用gather就得到了Y中每个样本的第一个点的坐标
二、把X中mask的点替换成Y中对应样本的第一个点
mask_true_indices = mask.unsqueeze(-1).expand_as(X)
new_Y = torch.where(mask_true_indices, first_points, X)
三、例子
代码如下(示例):
a = torch.randn(2,5,3)
a = tensor([[[ 0.2863, 0.1077, -0.1568],
[-1.8393, -2.4854, -2.3325],
[ 1.2281, 0.4034, -1.4813],
[-2.0025, -0.8318, 1.9351],
[-0.3946, -2.0275, -1.0836]],
[[-0.7766, -0.1757, 0.5138],
[-0.6950, -1.2234, -0.8401],
[-1.1701, 1.0071, -0.0117],
[ 0.9544, 0.0314, -1.7266],
[-0.5719, -0.3843, 1.6186]]])
mask = torch.tensor([[1, 0, 1, 0, 1],
[0, 1, 0, 1, 0]])
first = torch.tensor([[[1, 1, 1]],
[[0, 0, 0]]])
mask_indices = mask.unsqueeze(-1).enpand_as(a)
mask_indices2 = mask_indices.bool()
final = torch.where(mask_indices2, first, a)
final = tensor([[[ 1.0000, 1.0000, 1.0000],
[-1.8393, -2.4854, -2.3325],
[ 1.0000, 1.0000, 1.0000],
[-2.0025, -0.8318, 1.9351],
[ 1.0000, 1.0000, 1.0000]],
[[-0.7766, -0.1757, 0.5138],
[ 0.0000, 0.0000, 0.0000],
[-1.1701, 1.0071, -0.0117],
[ 0.0000, 0.0000, 0.0000],
[-0.5719, -0.3843, 1.6186]]])
总结
总之,就是用这种方法实现了把A的点云中离B距离近的点都用B中的点替换,同时避免了在网络训练过程中利用for循环而导致的时间浪费的问题。