写在前面
最近看代码,发现别人的代码里用到了一个神奇的操作,torch.pdist()
。查阅许久之后,对于他们的描述都不是很明白,遂结合描述,自行测试,结果记录于此,便于理解。
一、文档描述
关于torch.pdist()
的官方文档如下:
计算输入中每对行向量之间的p范数距离。这与
torch.norm(input[:,None]-input, dim=2, p=p])
的上三角部分相同,不包括对角线。如果行上是连续的,这个函数将很快。(Computes the p-norm distance between every pair of row vectors in the input. This is identical to the upper triangular portion, excluding the diagonal, of torch.norm(input[:, None] - input, dim=2, p=p). This function will be faster if the rows are contiguous.)
其实理解之后,对于他的描述才会感觉认同。但是不理解的时候,也看不太懂他的描述。
二、代码测试
我看到的代码如下:
torch.pdist(x, p=2) # 其中x为二维矩阵
因此,为了更好的理解torch.pdist()
,我需要去建立一个简单的二维矩阵,然后根据torch.pdist()
的原理,手写出其计算过程。(PS:之所以建立简单的二维矩阵,就是为了更好理解)
import torch
import numpy as np
_x = np.asarray([[1,2,3],[4,5,6],[7,8,9]])
# print(_x)
# array([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
# 官方的pdist
x = torch.Tensor(_x)
res0 = torch.pdist(x, p=2)
# print(res0)
# tensor([ 5.1962, 10.3923, 5.1962])
# 官方的解释
res1 = torch.norm(x[:,None]-x,dim=2,p=2)
# print(res0)
# tensor([[ 0.0000, 5.1962, 10.3923],
# [ 5.1962, 0.0000, 5.1962],
# [10.3923, 5.1962, 0.0000]])
# 取上三角部分,剔除掉对角线,就是[5.1962, 10.3923, 5.1962],但是这又需要看懂torch.norm(x[:,None]-x,dim=2,p=2)是什么意思
我的理解
文档里讲到,他是算两行之间的p norm,p是参数,容易知道2-范数的公式:
x
=
∣
x
1
−
x
2
∣
2
+
∣
y
1
−
y
2
∣
2
+
∣
z
1
−
z
2
∣
2
x = \sqrt{|x_1 - x_2|^2 + |y_1 - y_2|^2 + |z_1 - z_2|^2}
x=∣x1−x2∣2+∣y1−y2∣2+∣z1−z2∣2
看到公式,计算下两行之间的2-范数,就知道结果了:
5.1962
≊
∣
1
−
4
∣
2
+
∣
2
−
5
∣
2
+
∣
3
−
6
∣
2
5.1962 \approxeq \sqrt{|1 - 4|^2 + |2 - 5|^2 + |3 - 6|^2}
5.1962≊∣1−4∣2+∣2−5∣2+∣3−6∣2
10.3923
≊
∣
1
−
7
∣
2
+
∣
2
−
8
∣
2
+
∣
3
−
9
∣
2
10.3923\approxeq \sqrt{|1-7|^2 + |2-8|^2 + |3-9|^2}
10.3923≊∣1−7∣2+∣2−8∣2+∣3−9∣2
于是,就很明了了。