numpy广播机制
numpy 在算术运算期间采用“广播”来处理具有不同形状的 array ,即将较小的阵列在较大的阵列上“广播”,以便它们具有兼容的形状。
如下例
>>> a = np.array([1.0, 2.0, 3.0])
>>> b = 2.0
>>> a * b
array([2., 4., 6.])
标量 b 在算术运算期间被拉伸成一个与 a 形状相同的 array ,其中的新元素,是原始标量b的副本,拉伸类比只是概念性的。 在“广播”中 numpy 使用的是原始标量值b而不产生额外的数据副本。
广播规则
当对两个 array 进行操作时,numpy 会逐元素比较它们的形状。从尾(即最右边)维度开始,然后向左逐渐比较。只有当两个维度 1)相等
or 2)其中一个维度是1
时,这两个维度才会被认为是兼容。
如果不满足这些条件,则会抛出 ValueError:operands could not be broadcast together 异常,表明 array 的形状不兼容。最终结果 array 的每个维度尽可能不为 1 ,是两个操作数各个维度中较大的值 。
例如,有一个 256x256x3 的 RGB 值图片 array ,需要将图像中的每种颜色缩放不同的值,此时可以将图像乘以具有 3 个值的一维 array 。根据广播规则排列这两个 array 的尾维度大小,是兼容的:
图片(3d array): 256 x 256 x 3
缩放(1d array): 3
结果(3d array): 256 x 256 x 3
当比较的任一维度是 1 时,使用另一个,也就是说,大小为 1 的维度被拉伸或“复制”以匹配另一个维度。
在以下示例中,A 和 B 数组都有长度为 1 的维度,在广播操作期间扩展为更大的大小:
A (4d array): 8 x 1 x 6 x 1
B (3d array): 7 x 1 x 5
result (4d array): 8 x 7 x 6 x 5
以二维为例,更加方便的解释“广播”:
已知 a.shape
是(5,1),b.shape
是(1,6),c.shape
是(6,),d.shape
是(), d 是一个标量, a, b, c,和 d 都可以“广播”到维度 (5,6);
- a “广播”为一个 (5,6) array ,其中 a[:,0] 被“广播”到其他列,
- b “广播”为一个 (5,6) array ,其中 b[0,:] 被广播到其他行,
- c 类似于 (1,6) array ,其中 c[:] 广播到每一行,
- d 是标量,“广播”为 (5,6) array ,其中每个元素都一样,重复d值。
以下是不兼容的例子:
A (2d array): 2 x 1
B (3d array): 8 x 4 x 3 # 倒数第二个维度不兼容
>>> a = np.array([[ 0.0, 0.0, 0.0],
... [10.0, 10.0, 10.0],
... [20.0, 20.0, 20.0],
... [30.0, 30.0, 30.0]])
>>> b = np.array([1.0, 2.0, 3.0])
>>> a + b
array([[ 1., 2., 3.],
[11., 12., 13.],
[21., 22., 23.],
[31., 32., 33.]])
>>> b = np.array([1.0, 2.0, 3.0, 4.0])
>>> a + b
Traceback (most recent call last):
ValueError: operands could not be broadcast together with shapes (4,3) (4,)
在某些情况下,广播会拉伸两个 array 以形成一个大于任何一个初始 array 的结果 array 。
>>> a = np.array([0.0, 10.0, 20.0, 30.0])
>>> b = np.array([1.0, 2.0, 3.0])
>>> a[:, np.newaxis] + b
array([[ 1., 2., 3.],
[11., 12., 13.],
[21., 22., 23.],
[31., 32., 33.]])
newaxis
运算符将新轴插入到 a 中,使其成为二维 4x1 array 。将 4x1 array 与形状为 (3,) 的 b 组合,产生一个 4x3 array 。
现实应用
一个典型的例子是信息论、分类和其他相关领域中使用的矢量量化(VQ Vector Quantization)算法。 VQ 中的基本操作是在一组点(在 VQ 术语中称为codes
)中找到最接近给定点的点,称为观察点observation
。在下面显示的非常简单的二维案例中,观察值observation
描述了要分类的运动员的体重和身高。这些codes
代表不同类别的运动员。 找到最近的点需要计算观察点observation和每个codes之间的距离,最短距离是最佳匹配。在此示例中,codes[0]
是最接近的类别,表明该运动员可能是一名篮球运动员。
>>> from numpy import array, argmin, sqrt, sum
>>> observation = array([111.0, 188.0])
>>> codes = array([[102.0, 203.0],
... [132.0, 193.0],
... [45.0, 155.0],
... [57.0, 173.0]])
>>> diff = codes - observation # 这里发生“广播”
>>> dist = sqrt(sum(diff**2,axis=-1))
>>> argmin(dist)
0
在此示例中,observation
数组被拉伸以此来匹配codes
array的形状:
通常,将可能从数据库中读取的大量观察observation
结果与一组codes
进行比较:
Observation (2d array): 10 x 3
Codes (2d array): 5 x 3
Diff (3d array): 5 x 10 x 3
三维数组 diff
是广播的结果,不是必需的。大型数据集往往会生成计算效率低下的大型中间 array 。但是,如果使用 Python 循环围绕上述二维示例中的codes单独计算每个观察值observation,则使用更小的 array。
总结
广播是编写简短直观代码的强大工具,可以非常有效地进行计算。但是,在某些情况下,广播会为特定算法使用不必要的大量内存。在这些情况下,最好用 Python 编写算法的外循环,这也可能产生更具可读性的代码,因为使用广播的算法往往会随着广播中维度数量的增加而变得更难以解释。
https://numpy.org/doc/stable/user/basics.broadcasting.html