python广播机制(计算一个数组中所有点之间的距离)
今天遇到了需要计算一个数组中的所有n维点之间的距离(使用欧氏距离),使用两个for循环的话就会很慢,而使用广播机制来计算的话就会很快。这篇文章主要是阐述一下广播机制的原理(不再阐述广播需要的条件)。
使用两个for循环的python代码是:
distance=np.zeros((N,N))
def loop(coord):
# add your codes here
for i in range(coord.shape[0]):
for j in range(coord.shape[0]):
distance[i,j]=dis(coord[i],coord[j])
return distance
例如有一个3x2的数组,即有三个点,每一个点的坐标用x和y表示,现在想用一个矩阵表示这三个点两两之间的距离,即要输出一个3x3矩阵。用字母来表示,即有一个axb的数组,a表示数组的点的数量,b表示点的坐标的维度。
将上面原始数组复制成两份,一份在a的后面且b的前面扩展一个维度,使其维度变为ax1xb,这一步能让b这个维度上的数据在之后广播的时候扩展;另一份在a的前面扩展一个维度,变成1xaxb。
那么广播的时候两个数组都被广播成axaxb了,那么这个时候将两个数组相减
并且平方
,之后再在最后一个维度上求和
后开方
就可以得到结果。不过可能比较难理解的是,为什么这样就能做到两两相减?
当axb扩展成1xaxb再扩展成axaxb时(以3x3x2为例),这是在将原来的3x2数组整个复制三份,即扩展成原来的三倍,可以看成是在行方向扩展,即每一行是相同的,也就是
当3x2扩展成3x1x2再扩展成3x3x2时,是将2
这一维的数据复制了三份,即扩展成原来的三倍,可以看成是在列方向上扩展,即每一列都是相同的,也就是
这样的话有一个扩展后的数组的一列都是一样的,而另一个扩展后的数组的每一列是所有的点,两个数组相减就完成了计算所有点两两之间的距离了。
使用广播机制的python代码是:
def broadcast(coord):
# add your codes here
distance=np.sqrt(np.sum((coord[:,np.newaxis,:]-coord[np.newaxis,:,:])**2,-1))
return distance