pytorch和numpy的广播机制类似,搞懂了numpy的广播机制也就明白了pytorch了。
numpy中数组np.array的维数(ndim),维度长度(size), 数组的形状(shape)
不要把维数和形状搞混了
维数也叫做阶数,如果单纯一个序列[...],则是一维,如果以多个这样的序列为元素组成的数组[[...], [...],...,[..]],则是二维数组,2维数组中的元素是常量,如果二维数组的每一个元素用另外的序列表示,则成为三维数组...类似可以构成n维数组。
每个维度是一个矢量,在这个维度上数组的长度叫做维度的size
数组的形状是有维数和各维的长度决定,如三维数组的形状是(s1, s2, s3),s1,s2,s3分别是三个维度的长度
numpy中数组或者高维数组之间的点乘(以及元素与元素的其他运算)不要求拥有相同的形状,通过广播机制可以完成正确运算,但是必须遵循一些规则:
1、两个输入数组能不能运算要看每个维度的长度,如果每个维度长度相同或者维度长度为1,则可以进行运算;
比如输入形状(1,2)可以和输入形状(2, 1)以及(2, 2)进行广播机制的运算,不能和形状(2, 3)进行运算。
2、如果输入数组的维数不同,在维数小的数组的形状前面补充1,例如数组a1的形状是(2,4),数组a的形状是(3, 2, 4),在运算时a1的形状是(1, 2, 4),将后面两维作为整体进行运算
3、如果某个输入数组在某一维度的长度维1,则将该维度的第一个数与其他输入数组在该维度上依次进行运算
示例1:
a = np.array([
[[1, 4, 2, 1], [3, 9, 5, 1]],
[[1, 4, 2, 1], [3, 9, 5, 1]],
[[2, 5, 2, 1], [3, 7, 4, 1]]
])
a1 = np.array([[1, 2, 3, 5], [1, 2, 3, 5]])
c2 = a1 * a
输出:
[[[ 1 8 6 5]
[ 3 18 15 5]]
[[ 1 8 6 5]
[ 3 18 15 5]]
[[ 2 10 6 5]
[ 3 14 12 5]]]
示例2:
a = np.array([
[[1, 4, 2, 1], [3, 9, 5, 1]],
[[1, 4, 2, 1], [3, 9, 5, 1]],
[[2, 5, 2, 1], [3, 7, 4, 1]]
])
b = np.array([[[2, 3, 4, 1]]])
c = a * b
输出:
[[[ 2 12 8 1]
[ 6 27 20 1]]
[[ 2 12 8 1]
[ 6 27 20 1]]
[[ 4 15 8 1]
[ 6 21 16 1]]]