1.什么是广播机制
根据线性代数的运算规则我们知道,矩阵运算往往都是在两个矩阵维度相同或者相匹配时才能运算。比如加减法需要两个矩阵的维度相同,乘法需要前一个矩阵的列数与后一个矩阵的行数相等。那么在 numpy、tensor 里也是同样的道理,但是在机器学习的某些算法中会出现两个维度不相同也不匹配的矩阵进行运算,那么这时候就需要用广播机制来解决,通过广播机制,其tensor参数可以自动扩展为相等大小(不需要复制数据)。下面我们以tensor为例来解释什么是广播机制。
2.广播机制的规则
先来说下广播机制的规则,只有遵循下面的两个规则,两个张量才可以进行广播运算。
- 每个tensor至少有一个维度;
- 遍历tensor所有维度,注意,是从末尾开始遍历(从右往左开始遍历),两个tensor存在下列情况:
1. 这两个维度的大小相等
2. 某个维度 一个张量有,一个张量没有
3. 某个维度 一个张量有,一个张量也有但大小是1
2.1 广播机制的理解
满足上面的条件才可以进行广播机制,简言之,广播机制的目标是把两个维度不相等的张量变成一样的维度。举个例子来看下,我们以两个张量之间的相加为例演示,简单看下两个张量广播机制之后shape是什么,其他运算原理相同:
- 满足上面的广播机制条件,执行A.shape==B.shape;
- 形状对齐之后,对应位置进行相加运算,执行结果的shape:A.shape和B.shape对应位置的最大值,比如:A.shape=(1,2,3),B.shape=(3,1,3),那么A+B的shape是(3,2,3)。
看到这个例子可能还会有小伙伴不明白,这shape怎么变得,为什么最后是(3,2,3)呢,我怎么看不出来呢,别急,我们下面讲一下原理,很简单,看完例子你就明白了。下面我们针对上面的两个规则进行逐一举例。
3.代码举例
3.1 规则一
- 每个tensor至少有一个维度
import torch
a = torch.tensor([1,2,3])
b = torch.tensor([3])
c = a*b
print(a.shape)
print(b.shape)
print(c.shape)
print(c)
输出结果如下:
torch.Size([3])
torch.Size([1])
torch.Size([3])
tensor([3, 6, 9])
3.2 规则二
- 这两个维度的大小相等
- 某个维度 一个张量有,一个张量没有
- 某个维度 一个张量有,一个张量也有但大小是1
先看一个综合的例子:
import torch
a = torch.randint(0, 4, (4, 3, 2, 1)) # 元素值在 0 到 3 之间,张量形状为 (3, 2, 1)
b = torch.randint(0, 3, (3, 1, 1)) # 元素值在 0 到 2 之间,张量形状为 (1, 1)
c = a * b
print(a.shape)
print(b.shape)
print(c.shape)
print(c)
输出结果:
torch.Size([4, 3, 2, 1])
torch.Size([3, 1, 1])
torch.Size([4, 3, 2, 1])
tensor([[[[0],
[0]],
[[0],
[2]],
[[0],
[0]]],
[[[0],
[0]],
[[4],
[2]],
[[0],
[0]]],
[[[0],
[0]],
[[4],
[0]],
[[0],
[0]]],
[[[0],
[0]],
[[4],
[0]],
[[0],
[0]]]])
如上面代码所示,首先将两个张量维度从右开始对齐,两个张量最后一维大小相等,都为1,满足上面条件1;倒数第二个维度大小不相等,但第二个张量倒数第二维大小为1,满足上面条件3;倒数第三个维度大小相等都为3,满足上面条件1;第一个张量第一维度有,第二个张量没有,满足上面条件2,因此两个张量每个维度都符合上面广播条件,因此可以进行广播。
3.21 有一个张量缺少维度,一定可以进行 broadcast:
import torch
x = torch.rand(1, 2, 3, 4)
y = torch.rand(2, 3, 4)
print(x.shape)
print(y.shape)
z = x + y
print(z.shape)
print(x)
print(y)
print(z)
输出结果:
torch.Size([1, 2, 3, 4])
torch.Size([2, 3, 4])
torch.Size([1, 2, 3, 4])
tensor([[[[0.0094, 0.1863, 0.2657, 0.3782],
[0.3296, 0.7454, 0.2080, 0.4156],
[0.2092, 0.5414, 0.1053, 0.3872]],
[[0.8161, 0.3554, 0.7352, 0.2116],
[0.7459, 0.1662, 0.7555, 0.4548],
[0.2611, 0.0353, 0.1862, 0.5948]]]])
tensor([[[0.4637, 0.3938, 0.2039, 0.3892],
[0.4146, 0.8713, 0.3947, 0.5345],
[0.2401, 0.3800, 0.3747, 0.8381]],
[[0.0459, 0.1242, 0.3529, 0.1527],
[0.2361, 0.2850, 0.8671, 0.8040],
[0.6575, 0.4075, 0.8156, 0.2638]]])
tensor([[[[0.4730, 0.5801, 0.4695, 0.7674],
[0.7442, 1.6167, 0.6027, 0.9501],
[0.4493, 0.9214, 0.4800, 1.2253]],
[[0.8620, 0.4796, 1.0881, 0.3643],
[0.9820, 0.4512, 1.6227, 1.2588],
[0.9186, 0.4428, 1.0018, 0.8586]]]])
上面的张量y跟张量x相比缺少一个维度,根据广播机制的规则我们从最后一个维度进行匹配,后面三个维度都一样,张量y的缺少一个维度,于是触发广播机制。
3.22 两个张量的维度不相等,其中有一个张量的对应维度为1或者缺失,一定可以进行 broadcast:
import torch
x = torch.rand(1, 2, 3, 4)
y = torch.rand(2, 1, 1)
print(x.shape)
print(y.shape)
z = x + y
print(z.shape)
print(x)
print(y)
print(z)
输出结果:
torch.Size([1, 2, 3, 4])
torch.Size([2, 1, 1])
torch.Size([1, 2, 3, 4])
tensor([[[[0.8670, 0.0134, 0.7929, 0.4109],
[0.3595, 0.8457, 0.2819, 0.8470],
[0.5040, 0.9281, 0.9161, 0.7305]],
[[0.3798, 0.3866, 0.4680, 0.5744],
[0.6984, 0.6501, 0.2235, 0.3099],
[0.9861, 0.8598, 0.7635, 0.3238]]]])
tensor([[[0.3393]],
[[0.1775]]])
tensor([[[[1.2062, 0.3527, 1.1322, 0.7501],
[0.6987, 1.1850, 0.6212, 1.1863],
[0.8433, 1.2674, 1.2554, 1.0698]],
[[0.5574, 0.5641, 0.6455, 0.7519],
[0.8759, 0.8276, 0.4010, 0.4875],
[1.1636, 1.0373, 0.9410, 0.5013]]]])
以上就是广播机制的操作,只要记住几个规则就行了,注意tensor在进行运算的时候是从后往前匹配运算的。
3.3. 总结
针对以上的分析,其实我们可以把广播机制总结如下:两个shape完全相同的张量肯定能广播机制或者说不需要广播机制。张量shape不同的时候,从后往前看,维度不是1的时候维度数量必须相等,除非匹配到一个张量已经没有这个维度了(这个时候会自动增加维度),不然没法广播机制。维度匹配过程中数字以最大的为准。
4. 原地操作
补充一点,在进行广播机制的时候我们要注意一个原地操作运算,什么是原地操作运算?原地操作运算就是指改变一个tensor的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值。在pytorch中经常加后缀“”来代表原地操作符,例:.add _()
、.scatter()
,原地操作不允许tensor
使用广播机制那样来改变张量形状维度大小,如下例子所示。
import torch
x = torch.rand(1,3,1)
y = torch.rand(3,1,7)
print(x.shape)
print(y.shape)
z = x.add_(y)
print(z.shape)
print(x)
print(y)
print(z)
输出结果:
torch.Size([1, 3, 1])
torch.Size([3, 1, 7])
Traceback (most recent call last):
File "D:/program/Test/broadcast/test.py", line 8, in <module>
z = x.add_(y)
RuntimeError: output with shape [1, 3, 1] doesn't match the broadcast shape [3, 3, 7]