广播(broadcasting)指的是不同形状的数组之间的算术运算的执行方式。它是一种非常强大的功能。将标量跟数组合并时就会发生最简单的广播。
广播的原则如果两个数组的后缘维度(trailing dimension, 即从末尾开始算起的维度)的轴长度相符或者其中一方长度为1,则认为它们是广播兼容的。广播会在缺失和(或)长度为1的维度上进行。
import numpy as np
from numpy.random import randn
arr = np.arange(5)
print(arr)
#out[]:[0 1 2 3 4]
print(arr * 4)
#out[]:[ 0 4 8 12 16]
#在这个乘法运算中,标量值4被广播到了其他所有的元素上
#通过减去列平均值的方式对数组的每一列进行行距平化处理
arr = randn(4, 3)
print(arr.mean(0))
#out[]:
# [-0.19761277 -0.01223537 -0.13039332]
demeaned = arr - arr.mean(0)
print(demeaned)
#out[]:
# [[-0.08902603 0.21504452 1.44062124]
# [-1.50993883 1.10333762 -0.34120107]
# [ 0.77850955 -0.14465187 -0.75158805]
# [ 0.82045531 -1.17373026 -0.34783212]]
#对数组的每一行进行行距平化处理
arr = randn(4, 3)
row_means = arr.mean(1)
print(row_means)
#out[]:[ 0.40879942 -0.36268125 -0.15265728 -0.34711618]
print(row_means.reshape((4, 1)))
#out[]:
# [[ 0.40879942]
# [-0.36268125]
# [-0.15265728]
# [-0.34711618]]
demeaned = arr - row_means.reshape((4,1))
print(demeaned.mean(1))
#out[]:[ 0.00000000e+00 -7.40148683e-17 0.00000000e+00 -1.85037171e-17]
demeaned = arr - row_means.reshape((4,1))
需要专门为了广播而添加一个长度为1的新轴,虽然reshape是一个办法,但插入轴时需要构造一个表示新形状的元组,所以可以通过np.newaxis属性以及“全”切片来插入新轴。
arr = np.zeros((4,4))
arr_3d = arr[:, np.newaxis, :]
print(arr_3d.shape)
#out[]:(4, 1, 4)
arr_1d = np.random.normal(size=3)
print(arr_1d[:, np.newaxis])
# [[ 0.2219978 ]
# [ 0.09997794]
# [-0.96239615]]
print(arr_1d[np.newaxis, :])
# [[ 0.2219978 0.09997794 -0.96239615]]
如果我们有一个三维数组,并希望对轴2进行距平化,那么只需要编写下面这样的代码就可以了
arr = randn(3, 4, 5)
depth_means = arr.mean(2)
demeaned = arr - depth_means[:, :, np.newaxis]
算术运算所遵循的广播原则同样也适用于通过索引机制设置数组值的操作。
arr = np.zeros((4, 3))
arr[:] = 5
print(arr)
#out[]:
# [[5. 5. 5.]
# [5. 5. 5.]
# [5. 5. 5.]
# [5. 5. 5.]]
假设我们想要用一个一维数组来设置目标数组的各列。只要保证形状兼容就可以了。
arr = np.zeros((4, 3))
col = np.array([1.28, -0.42, 0.44, 1.6])
arr[:] = col[:, np.newaxis]
print(arr)
#out[]:
# [[ 1.28 1.28 1.28]
# [-0.42 -0.42 -0.42]
# [ 0.44 0.44 0.44]
# [ 1.6 1.6 1.6 ]]
arr[: 2] = [[-1.37], [0.509]]
print(arr)
#out[]:
# [[-1.37 -1.37 -1.37 ]
# [ 0.509 0.509 0.509]
# [ 0.44 0.44 0.44 ]
# [ 1.6 1.6 1.6 ]]