numpy的universal function处理数据时,要求输入数组的shape必须一致,当数组的shape不一致时,则会产生广播机制;
广播机制会调整shape,使数组运算满足规则。
广播机制在调整ndarray时的四条规则:
1 让所有输入数组都向其中shape最长的ndarray看齐,shape中不足的部分都通过在前面添加1补齐
2 输出数组的shape是输入数组shape的各个轴上的最大值
3 如果输入数组的某个轴和输出数组的对应轴的长度相同或者其长度为1时,这个数组能够用来计算,否则出错
4 当输入数组的某个轴的长度为1时,沿着此轴运算时都用此轴上的第一组值
实例一
import numpy as np a = np.array([[10,10,10],[20,20,20],[30,30,30]]) print(a.shape) #(3, 3) print(a) # [[10 10 10] # [20 20 20] # [30 30 30]] b = np.array([1,2,3]) print(b.shape) #(3,) print(b) # [1 2 3] c = a + b print(c.shape) #(3, 3) print(c) # [[11 12 13] # [21 22 23] # [31 32 33]]
解析
1 这里最长的是a,shape=(3,3),b的shape为1行3列 (3,),对于a而言,b 的行是不足的,因此补足后为
b.shape = 1,3 print(b.shape) #(1, 3) print(b) #[[1 2 3]]
2 输出数组的shape为输入数组shape在各轴上的最大值,也即(3,3)
4 当输入数组的某个轴的长度为1时,沿着此轴运算时都用此轴上的第一组值,如下图所示
实例二
import numpy as np a = np.array([[0,0,0],[10,10,10],[20,20,20],[30,30,30]]) print(a.shape) #(4, 3) print(a) # [[ 0 0 0] # [10 10 10] # [20 20 20] # [30 30 30]] b = np.array([1,2,3]) print(b.shape) #(3,) print(b) # [1 2 3] c = a + b print(c.shape) #(4, 3) print(c) # [[ 1 2 3] # [11 12 13] # [21 22 23] # [31 32 33]]
该示例中最长的也是a,a.shape = (4,3),b.shape = (3,),所以可以b的行是不足的;同样需要将其补足(补足时通过在其前面加1实现)也即b.shape = (1,3),也即 1 行 3 列;
输出组的shape为输入组的最大值,即(4,3),
此时可以得到最终结果。
示例三
import numpy as np a = np.array([[0,0,0],[10,10,10],[20,20,20],[30,30,30]]) print(a.shape) #(4, 3) b = np.array([1,2,3,4]) print(b.shape) #(,) c = a + b print(c.shape) #无法广播
无法广播,运行失败
ValueError: operands could not be broadcast together with shapes (4,3) (4,)
也就是说,列数不同的ndarray是不能广播的。
示例四
import numpy as np a = np.array([[0,0,0],[10,10,10],[20,20,20],[30,30,30]]) print(a.shape) #(4, 3) print(a) # [[ 0 0 0] # [10 10 10] # [20 20 20] # [30 30 30]] b = np.array([1,2,3,4]).reshape(-1,1) print(b.shape) #(4, 1) print(b) # [[1] # [2] # [3] # [4]] c = a + b print(c.shape) #(4, 3) print(c) # [[ 1 1 1] # [12 12 12] # [23 23 23] # [34 34 34]]
实现方式
从示例二、三、四对比可以发现,广播可以实现ndarray.shape值不相等的相加,但是必须在一个维度是相等的,示例三中两个维度均不相等,就报错了。
此外,经过尝试,广播运算还可以进行 加+ 、减- 、乘* 、除/ 运算
参考:NumPy广播、NumPy的详细教程(官网手册翻译)、2.2.1 广播