pytorch中计算两点之间的距离

pytorch中计算两点之间的距离实现及代码解读
代码如下:

def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.

    src * dst^T = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; [B, N, 1]
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; [B, M, 1]
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src*dst^T

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) #permute为转置,[B, N, M]
    dist += torch.sum(src ** 2, -1).view(B, N, 1) #[B, N, M] + [B, N, 1],dist的每一列都加上后面的列值
    dist += torch.sum(dst ** 2, -1).view(B, 1, M) #[B, N, M] + [B, 1, M],dist的每一行都加上后面的行值
    return dist

这里计算的是原点(src)集合中N个点到目标点(dst)集合中M点的距离(平方距离,没有开根号),以Batch为单位,输出B×N×M的tensor。这里的实现比较巧妙(实际上就是因式分解):
标准的计算距离公式为:
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
=(xnxn + ynyn + znzn) + (xmxm + ymym + zmzm) - 2*(xn * xm + yn * ym + zn * zm)
第一项和第二项直接是两个点集坐标矩阵自身的平方和:
sum(src^2, dim=-1) = xnxn + ynyn + znzn; [B, N, 1]
sum(dst^2, dim=-1) = xm
xm + ymym + zmzm; [B, M, 1]
第三项为src矩阵和dst矩阵的转置进行矩阵乘:
src * dst^T = xn * xm + yn * ym + zn * zm;[B, N, M]
因此最终计算公式为:
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src2,dim=-1)+sum(dst2,dim=-1)-2src^Tdst
注释写的很详细了,需要注意的是最后两行,按理说不同形状的矩阵不能加减,但经过验证发现,这里的实现是dist的每一列都加上后面的列值,每一行都加上后面的行值。

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值