#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import torch
def by_l2_square(a, b):
    """Return pairwise squared Euclidean distances computed by broadcasting.

    Expected layout: ``a`` is (..., n, d) and ``b`` is (..., m, d) with
    broadcastable leading dims.  The result is (..., n, m), where entry
    [..., i, j] equals the sum over the last axis of (a_i - b_j) ** 2.
    Note: this materialises an (..., n, m, d) intermediate.
    """
    # Insert a singleton axis so that rows of a pair with every row of b:
    # (..., n, 1, d) - (..., 1, m, d) -> (..., n, m, d)
    diff = a[..., None, :] - b[..., None, :, :]
    return diff.pow(2).sum(dim=-1)
def by_p_norm(a, b, p):
    """Return pairwise p-norm distances computed by broadcasting.

    Expected layout: ``a`` is (..., n, d) and ``b`` is (..., m, d) with
    broadcastable leading dims.  ``p`` is forwarded as ``ord`` to
    ``torch.linalg.vector_norm``.  The result is (..., n, m).
    """
    # (..., n, 1, d) - (..., 1, m, d) -> (..., n, m, d); norm over the last axis
    pairwise_diff = a[..., None, :] - b[..., None, :, :]
    return torch.linalg.vector_norm(pairwise_diff, ord=p, dim=-1)
def origin(a, b):
    """Reference pairwise squared Euclidean distance using explicit loops.

    Intentionally slow triple loop, used to cross-check the vectorised
    implementations.

    Args:
        a: tensor of shape (B, n, d).
        b: tensor of shape (B, m, d); its batch size is assumed to equal a's.

    Returns:
        Tensor C of shape (B, n, m) with
        C[k, i, j] = sum((a[k, i] - b[k, j]) ** 2), allocated on a's device.
        For floating-point inputs the dtype is the promoted dtype of a and b,
        matching the vectorised versions. Otherwise it is torch's default
        floating dtype.
    """
    B, n, _ = a.shape
    _, m, _ = b.shape
    # Follow the inputs' dtype/device so float64 results are not silently
    # downcast to float32 and CUDA inputs do not yield a CPU result.
    dtype = torch.promote_types(a.dtype, b.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    C = torch.empty(B, n, m, dtype=dtype, device=a.device)
    for batch in range(B):
        for i in range(n):
            for j in range(m):
                C[batch, i, j] = torch.sum((a[batch, i] - b[batch, j]) ** 2)
    return C
if __name__ == '__main__':
    # Demo: check that the three implementations agree on random 2-D points.
    num_a, num_b = 5, 4
    pts_a = torch.rand(1, num_a, 2)
    pts_b = torch.rand(1, num_b, 2)
    sq_vectorized = by_l2_square(pts_a, pts_b)
    sq_from_norm = by_p_norm(pts_a, pts_b, 2) ** 2
    sq_loop = origin(pts_a, pts_b)
    comparisons = (
        (sq_vectorized, sq_loop),
        (sq_vectorized, sq_from_norm),
        (sq_from_norm, sq_loop),
    )
    for lhs, rhs in comparisons:
        print(torch.allclose(lhs, rhs))
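# 坐标转距离矩阵 — "coordinates to distance matrix" (title of the source article).
# 最新推荐文章于 2023-11-04 16:00:55 发布 — blog-page residue ("latest recommended article published 2023-11-04 16:00:55").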