广播是PyTorch以及Numpy中的一个重要机制,很多人在学习PyTorch以及Numpy的时候往往对广播机制一掠而过,最终只是略懂皮毛,然而在实际中,如果不能全面的了解广播机制,可能很多可以并行执行的操作自己在实现时往往叠加了多层for循环,看别人的源码往往百思不得其解,本人也受过对广播机制一知半解的毒害。本文的目的在于全面的掌握广播机制到底如何执行,至于广播机制是什么之类的概念则不涉及。
可以执行广播机制的情况
- 两个数组的对应维度相同
- 两个数组对应维度的较低维维度为1
- 两个数组中其中一个数组缺少维度
两个数组对应维度相同的情况
a = torch.tensor((5, 7, 9))
b = torch.tensor((5, 7, 9))
# a和b的形状完全相同,此时可以直接执行操作
# 其实这种情况不需要广播,但既然官方文档中提到这种情况,那就提出来
c = a + b
两个数组对应维度的较低维度为1的情况
a = torch.tensor((5, 7, 9))
b = torch.tensor((5, 1, 1))
# b的后两维都为1,因此可以执行广播操作,广播后的结果是(5, 7, 9),即两者对应维度处维度最大的值
a = a + b # 广播后的结果形状为(5, 7, 9), 即两者对应维度处维度最大的值
c = torch.tensor((5, 1, 9))
d = torch.tensor((5, 7, 1))
c = c + d # 广播后的结果形状为(5, 7, 9), 即两者对应维度处维度最大的值
e = torch.tensor((5, 7, 9))
f = torch.tensor((5, 7, 3))
# 相对应的,此处的操作是不合法的,因为b的最后一个维度不是1,而是3,即使9是3的倍数,也无法执行广播
e = e + f # 此处操作不合法
# 以下是很少注意到的特殊情况
c = torch.tensor((1, 7, 9))
c1 = torch.tensor((5, 7, 9))
d = torch.tensor((0, 7, 9))
# 此处也是合法的,0是一个特殊的存在,该维度对应的数组维度必须为1才能够实现广播,即1维会广播到0维
c = c + d
# 下面操作不合法,维度为5无法广播到维度为0,当存在0维度时,所有维度都是广播到0维的一方
c1 = c1 + d
e = torch.tensor((5, 7, 9))
f = torch.tensor((5, 7, 0))
# 此处操作不合法,因为e的最后一维不为1,0在中间维度也是相同原则
e = e + f
其中一个数组缺失维度的情况
a = torch.tensor((5, 7, 9))
b = torch.tensor((7, 9))
# 对于这种情况,首先从两个数组的末位维度开始,进行广播,此例中均为9,
# 其次从次末位开始,此处均为7,依#次类推,但发现b的相对于a缺失一个
# 维度,此时,首先b会在首位新增一个维度(1,7,9),随后再执行广播
# 操作。不止于这种情况,广播机制的维度匹配顺序均是上述所说。
a = a + b # 合法,输出结果形状为(5,7,9)