替换一个batch中每个样本的mask位置的元素

替换一个batch中每个样本的mask位置的元素,用gather以及where函数


问题描述

给定一个batch中X,Y两个点云数组,形状分别为(B, N1, 3)和(B, N2, 3) ,计算两个点云的点对点距离,获得A数组中距离小于阈值的点的mask,并将这些点替换为B数组中每个样本所对应的第一个点。
简单点说,就是把A的点云中离B距离近的点都用B中的点替换。

一、获取Y中每个样本的第一个点

X = torch.randn(2,
Y = torch.
dist = torch.cdist(X, Y) * 1000
threshold = 20
min_dist, min_idx = torch.min(dist, dim=2)
mask = min_dist < threshold
batch_indices = torch.zeros(B, dtype=torch.long)
first_points = torch.gather(Y, 1, batch_indices.view(B, 1, 1).expand(-1, 1, 3))

用gather就得到了Y中每个样本的第一个点的坐标

二、把X中mask的点替换成Y中对应样本的第一个点

mask_true_indices = mask.unsqueeze(-1).expand_as(X)
new_Y = torch.where(mask_true_indices, first_points, X)

三、例子

代码如下(示例):

a = torch.randn(2,5,3)
a = tensor([[[ 0.2863,  0.1077, -0.1568],
         [-1.8393, -2.4854, -2.3325],
         [ 1.2281,  0.4034, -1.4813],
         [-2.0025, -0.8318,  1.9351],
         [-0.3946, -2.0275, -1.0836]],

        [[-0.7766, -0.1757,  0.5138],
         [-0.6950, -1.2234, -0.8401],
         [-1.1701,  1.0071, -0.0117],
         [ 0.9544,  0.0314, -1.7266],
         [-0.5719, -0.3843,  1.6186]]])
mask = torch.tensor([[1, 0, 1, 0, 1],
                     [0, 1, 0, 1, 0]])
first = torch.tensor([[[1, 1, 1]],
                     [[0, 0, 0]]])
mask_indices = mask.unsqueeze(-1).enpand_as(a)
mask_indices2 = mask_indices.bool()
final = torch.where(mask_indices2, first, a)
final = tensor([[[ 1.0000,  1.0000,  1.0000],
         [-1.8393, -2.4854, -2.3325],
         [ 1.0000,  1.0000,  1.0000],
         [-2.0025, -0.8318,  1.9351],
         [ 1.0000,  1.0000,  1.0000]],

        [[-0.7766, -0.1757,  0.5138],
         [ 0.0000,  0.0000,  0.0000],
         [-1.1701,  1.0071, -0.0117],
         [ 0.0000,  0.0000,  0.0000],
         [-0.5719, -0.3843,  1.6186]]])

总结

总之,就是用这种方法实现了把A的点云中离B距离近的点都用B中的点替换,同时避免了在网络训练过程中利用for循环而导致的时间浪费的问题。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值