什么是广播运算?
1、概念
广播运算初看有点神秘,其实它类似于分数加减法的通分。通分用于解决分母不同的分数加减运算问题,它利用分数基本性质把分母调整为原来分母的最小公倍数。广播用于解决不同形状数组的运算问题,包括加法、减法、点乘(对应元素相乘)、点除(对应元素相除乘)、比较等,如对列向量a和行向量b做加法运算,
a
+
b
=
[
1
2
]
+
[
3
4
]
a+b= \left[ \begin{matrix}1\\2 \end{matrix} \right]+\left[\begin{matrix}3&4 \end{matrix} \right]
a+b=[12]+[34]
由于二者形状不同,a+b在代数上是行不通的。但是python通过广播操作,把a、b的形状统一为
2
×
2
2\times2
2×2 的数组之后,就能够执行加法运算了。
a
+
b
=
[
1
1
2
2
]
+
[
3
4
3
4
]
=
[
4
5
5
6
]
a+b=\left[\begin{matrix}1&1\\2&2 \end{matrix} \right] + \left[\begin{matrix}3&4\\3&4 \end{matrix} \right]=\left[\begin{matrix}4&5\\5&6 \end{matrix} \right]
a+b=[1212]+[3344]=[4556]
a = torch.tensor([[1],[2]])
print(a)
b = torch.tensor([3,4])
print(b)
print(a+b)
结果
[[4 5]
[5 6]]
广播运算有什么用?如果要从a+b中选择大于5的元素,一种写法是不用广播的,
a
+
b
>
[
5
5
5
5
]
a+b>\left[\begin{matrix} 5&5\\ 5&5\\ \end{matrix}\right]
a+b>[5555]
另一种是使用广播的(广播把自动5变为
2
×
2
2\times2
2×2 的矩阵),
a
+
b
>
5
a+b>5
a+b>5
您认为哪种写法更加自然?
2、规则
广播统一两个数组的维度,具体做法有两步,分别起到统一维度和统一长度的作用。
(1)将低维变为高维
如
a
a
a 是二维的,
a
=
[
1
2
3
4
5
6
]
a=\left[\begin{matrix} 1&2&3\\ 4&5&6 \end{matrix} \right]
a=[142536]
b
b
b 是一维的,
b
=
[
1
2
3
]
b=\left[\begin{matrix}1&2&3\end{matrix}\right]
b=[123]
广播首先将
b
b
b 变为二维数组
b
=
[
[
1
2
3
]
]
b=\left[\left[\begin{matrix} 1&2&3 \end{matrix} \right]\right]
b=[[123]]
使得
a
、
b
a、b
a、b都是二维数组,
(2)将1维变为n维
现在,
a
a
a 形状为
2
×
3
2\times3
2×3,
b
b
b 形状为
1
×
3
1\times3
1×3,二者形状不同,需要为
b
b
b 添加新行(复制第0行),把
b
b
b 的行数也变为2,使得各维的长度也相等。
这里解释一下维度的含义。在计算机中提到维度时,是指描述一个分量需要几个下标,前面(1)中的维度就是如此。在数学中提到维度时,指的是分量的个数,实际上也是各维的长度,如(2)中的维度用法。下面构造
a
、
b
a、b
a、b两个数组如下,
a
=
[
0
1
]
,
b
=
[
0
1
2
]
a=\left[\begin{matrix}0\\1\end{matrix}\right], \ \ \ b=\left[\begin{matrix}0&1&2\end{matrix}\right]
a=[01], b=[012]
观察广播的效果
a = torch.tensor([[0],[1]])
print(a)
b = torch.tensor([0,1,2])
print(b)
c = torch.broadcast_tensors(a, b)
print(c)
结果
tensor([[0],
[1]])
tensor([0, 1, 2])
(tensor([[0, 0, 0],
[1, 1, 1]]),
tensor([[0, 1, 2],
[0, 1, 2]]))
结果
c
c
c 有两个数组,分别是对
a
、
b
a、b
a、b 广播的结果。
急转弯,请写出下列代码的执行结果:
a = torch.tensor([[1,2]])
b = torch.tensor([[3],[4]])
c = a*b
分析:* 表示点乘运算,即相同形状的两个矩阵的对应元素相乘的运算。
a
a
a 的形状为 (1,2),
b
b
b 的形状为 (2,1),
需要用广播统一二者的形状,使得二者的形状都是 (2,2),
即
a
a
a 要多一行,
b
b
b 要多一列。通过拷贝,
a
a
a 成为 [[1,2], [1,2]],
b
b
b 成为 [[3,3], [4,4]],完成广播后,现在可以做对应相乘的操作了,结果为
tensor([[3, 6],
[4, 8]])
3、应用
下面来作一道习题(选自numpy习题100道中的第93题)
93. Consider two arrays A and B of shape (8,3)
and (2,2). How to find rows of A that contain
elements of each row of B regardless of the
order of the elements in B? (★★★)
例如,A、B分别是如下数组
A = torch.Tensor([[4, 9, 3],
[0, 3, 9],
[7, 3, 7],
[3, 1, 6],
[6, 9, 8],
[6, 6, 8],
[4, 3, 6],
[9, 1, 4]])
B = torch.Tensor([[4, 1],
[9, 9]])
A数组中的第0行[4,9,3]含有B数组两行中的元素、第7行[9,1,4]也含有B数组两行中的元素,所以结果为[0, 7]。
torch.manual_seed(0)
A = torch.randint(10,size=(8,3)).reshape(8,3,1,1)
B = torch.randint(10,size=(2,2))
C = A==B
rows = torch.where(C.any(3).any(1).all(1))[0]
print(rows)
代码运行结果
tensor([0, 7])
代码中,
torch.manual_seed(0) 指定随机数的种子,用于再次产生相同的随机数。A、B是题目要求形状的数组,且A的形状被调整为四维数组。A==B要求按位置比较A、B的对应元素是否相等,由于A、B的形状不同,所以系统用广播将A、B的形状调整为(8,3,2,2)。A中的每个元素被重复4遍,形成(2,2)的数组,分别与B的四个元素相比较,比较结果是逻辑值存于C数组中。C.any(3).any(1).all(1)的理解稍有难度。
先考虑更直观的问题。问题1,如何计算B数组中有没有1?
print(B==1)
tensor([[False, True],
[False, False]])
只需要用any来检测所有元素中是否有True即可。
print((B==1).any())
tensor(True)
继续,问题2,如何计算B数组的两行中都有1吗?any(1)的参数1表示方向,与sum是一样的。
print((B==1).any(1))
tensor([ True, False])
Ture表示第0行有1,False表示第1行没有1,两行都有1要求两个结果都是True,用all做判断如下,
print((B==1).any(1).all())
tensor(False)
结果与事实相符。
问题3,如何计算三个数中既有B的0行中的元素,也有B的1行中的元素?考虑A[0]中的三个数4、9、3。
print((A[0]==B))
tensor([[[ True, False], #B中有4吗?
[False, False]],
[[False, False], #B中有9吗?
[ True, True]],
[[False, False], #B中有3吗?
[False, False]]])
print((A[0]==B).any(2))
any(2)的参数是哪一维?这涉及到如何想象四维数组,后面会总结出来,这倒未必是紧要的知识点。any(2)就是上面的6横排,所以得到下面的6个判断结果(每排中有True,结果就为True)分别表示4是B的第0行,但不是第1行元素,9不是B的第0行,但是第1行元素,3不是B的第0行、也不是第1行元素。
tensor([[ True, False],
[False, True],
[False, False]])
这个结果中,左边一列表示有没有B的0列的元素,右边一列表示有没有B的1列的元素,用ayn(0)判断如下
print((A[0]==B).any(2).any(0))
tensor([True, True])
上面的结果表明,4、9、3中有B的0行元素,也有1行元素,再加个all就判断出是否有两行元素(两个都是True)。
print((A[0]==B).any(2).any(0).all())
tensor(True)
问题4,8个元素又如何?
(A==B).any(3).any(1).all(1)
结果中的True代表A的某行满足要求,False则表示不满足要求。
tensor([ True, False, False, False, False, False, False, True])
加个where是得出满足条件的行的下标,
torch.where(C.any(3).any(1).all(1))
(tensor([0, 7]),)
可见结果是个元组,再放上最后面的[0]就从元组中析出了结果。