Numpy 广播规则:两个数组的形状即 shape,从后往前查看,要么严格相等,要么其中一个数组的在当前查看的维度上的长度为1,或者维度缺失,这样就能匹配,(并在相应维度上做拷贝扩充,但只是概念上假想的拷贝,而不是真正意义上内存上的拷贝),满足广播条件,否则不满足广播条件,程序报错。
广播(Broadcast)是 numpy 对不同形状(shape)的数组进行数值计算的方式, 对数组的算术运算通常在相应的元素上进行。
import numpy as np
a = np.array([[1,2,3,4],[5,6,7,8]])
b = np.array([[10,20,30,40],[50,60,70,80]])
print(a,a.shape)
print(b,b.shape)
print(a+b,(a+b).shape)
广播实列
import numpy as np
a = np.array([[ 0, 0, 0],
[10,10,10],
[20,20,20],
[30,30,30]])
b = np.array([1,2,3])
print(a,a.shape)
print(b,b.shape)
print(a+b,(a+b).shape)
广播的规则:
- 让所有输入数组都向其中形状最长的数组看齐,形状中不足的部分都通过在前面加 1 补齐。
- 输出数组的形状是输入数组形状的各个维度上的最大值。
- 如果输入数组的某个维度和输出数组的对应维度的长度相同或者其长度为 1 时,这个数组能够用来计算,否则出错。
- 当输入数组的某个维度的长度为 1 时,沿着此维度运算时都用此维度上的第一组值。
简单理解:对两个数组,分别比较他们的每一个维度(若其中一个数组没有当前维度则忽略),满足:
- 数组拥有相同形状。
- 当前维度的值相等。
- 当前维度的值有一个是 1。
若条件不满足,抛出 “ValueError: frames are not aligned” 异常。
PyTorch 广播规则:两个张量的 shape 从后往前逐个检查,要么对应维度相等,要么其中一个的维度上长度是1,或者其中一个的维度缺失,则在相应维度上扩充,使得和另一个张量的对应维度上的长度相等。否则广播失败,程序报错。
import torch
x=torch.empty(5,1,4,1)
y=torch.empty(3,1,1)
print(x.shape)
print(y.shape)
print((x+y).shape)
import torch
x=torch.empty(5,1,4,1)
y=torch.empty(3,2,1)
print(x.shape)
print(y.shape)
print((x+y).shape)