PyTorch以及Numpy中的广播机制

广播是PyTorch以及Numpy中的一个重要机制,很多人在学习PyTorch以及Numpy的时候往往对广播机制一掠而过,最终只是略懂皮毛,然而在实际中,如果不能全面的了解广播机制,可能很多可以并行执行的操作自己在实现时往往叠加了多层for循环,看别人的源码往往百思不得其解,本人也受过对广播机制一知半解的毒害。本文的目的在于全面的掌握广播机制到底如何执行,至于广播机制是什么之类的概念则不涉及。

可以执行广播机制的情况
  1. 两个数组的对应维度相同
  2. 两个数组对应维度的较低维维度为1
  3. 两个数组中其中一个数组缺少维度
 两个数组对应维度相同的情况
a = torch.tensor((5, 7, 9))
b = torch.tensor((5, 7, 9))
# a和b的形状完全相同,此时可以直接执行操作
# 其实这种情况不需要广播,但既然官方文档中提到这种情况,那就提出来
c = a + b
两个数组对应维度的较低维度为1的情况
a = torch.tensor((5, 7, 9))
b = torch.tensor((5, 1, 1))
# b的后两维都为1,因此可以执行广播操作,广播后的结果是(5, 7, 9),即两者对应维度处维度最大的值
a = a + b  # 广播后的结果形状为(5, 7, 9), 即两者对应维度处维度最大的值

c = torch.tensor((5, 1, 9))
d = torch.tensor((5, 7, 1))
c = c + d  # 广播后的结果形状为(5, 7, 9), 即两者对应维度处维度最大的值

e = torch.tensor((5, 7, 9))
f = torch.tensor((5, 7, 3))
# 相对应的,此处的操作是不合法的,因为b的最后一个维度不是1,而是3,即使9是3的倍数,也无法执行广播
e = e + f  # 此处操作不合法

# 以下是很少注意到的特殊情况

c = torch.tensor((1, 7, 9))
c1 = torch.tensor((5, 7, 9))
d = torch.tensor((0, 7, 9))
# 此处也是合法的,0是一个特殊的存在,该维度对应的数组维度必须为1才能够实现广播,即1维会广播到0维
c = c + d  
# 下面操作不合法,维度为5无法广播到维度为0,当存在0维度时,所有维度都是广播到0维的一方
c1 = c1 + d

e = torch.tensor((5, 7, 9))
f = torch.tensor((5, 7, 0))
# 此处操作不合法,因为e的最后一维不为1,0在中间维度也是相同原则
e = e + f  
其中一个数组缺失维度的情况
a = torch.tensor((5, 7, 9))
b = torch.tensor((7, 9))
# 对于这种情况,首先从两个数组的末位维度开始,进行广播,此例中均为9,
# 其次从次末位开始,此处均为7,依#次类推,但发现b的相对于a缺失一个
# 维度,此时,首先b会在首位新增一个维度(1,7,9),随后再执行广播
# 操作。不止于这种情况,广播机制的维度匹配顺序均是上述所说。
a = a + b  # 合法,输出结果形状为(5,7,9)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值