N(xn,yn,zn)与M(xm,ym,zm)的欧式距离:
距离平方:
def square_distance(src, dst):
‘’‘
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M] 返回点平方距离
‘’‘
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
return dist
dst.permute将张量维度从[B, M, C]变成[B, C,M],便于矩阵乘法运算,torch.matmul()计算,并乘以-2,即矩阵计算后的张量为[B,N,M]。
torch.sum(src**2,-1)用来计算:,计算后张量为[B,N,1],并与第1步的张量[B,N,M]相加,即。相加(两个张量维度都是3,且对应轴的值需一样或者为1;相加时将为1的轴进行复制扩充,得到两个维度完全相同的重量;然后对应位置相加即可)后的张量维数为[B,N,M]。
torch.sum(dst**2,-1)用来计算:,计算后张量为[B,1,M],并与第2步的张量[B,N,M]相加,即。相加后的张量维数为[B,N,M]。
由此计算出点N与M之间的距离平方。