Broadcasting机制

一、Broadcasting机制介绍

  1. Broadcasting(广播)机制是numpy对不同shape的数组进行数值计算的一种方式。
  2. Numpy的通用函数中要求输入的两个tensor的shape一致。当两个tensor的shape不一致时,会使用Broadcasting机制。

二、Broadcast机制规则

  1. 规则1:让所有输入的tensor向dim最大的tensor看齐,dim不足,前补1
  2. 规则2:将输入的tensor的某一个shape为1的维度拉伸
  3. 规则3:如果存在两个tensor的shape在任何维度均不匹配,且均没有等于1的维度,则会报错

三、实例说明

  1. Broadcasting理解
    1. 三部分代码分别对应下图的三个实例
a = torch.tensor([ [0,0,0],[10,10,10],[20,20,20],[30,30,30]]) # torch.Size([4, 3])
b = torch.tensor([ [0,1,2] , [0,1,2],[0,1,2],[0,1,2]]) # torch.Size([4, 3])
print("a + b == \n",a + b)
# 由于a和b的shape相同,不需要使用Broadcasting机制
# ---------------------------------------------------------------------
a = torch.tensor([ [0,0,0],[10,10,10],[20,20,20],[30,30,30]]) # torch.Size([4, 3])
b = torch.tensor([0,1,2]) # torch.Size([3])
print("a + b == \n",a + b)
# a.shape = [4,3] a.dim = 2; b.shape = [3], b.dim = 1, shape不同,需要使用Broadcasting机制
# 由于b.dim = 1 < a.dim = 2 , b向a看齐
# 相当于b.unsqueeze(0),使其shape变为[1,3],然后再使用expand(4,-1)使其shape变为[4,3]

# ---------------------------------------------------------------------
a = torch.tensor([[0],[10],[20],[30]]) # torch.Size([4, 1])
b = torch.tensor([0,1,2]) # torch.Size([3])
print("a + b == \n",a + b)
# a.shape = [4,1] a.dim = 2; b.shape = [3], b.dim = 1,  shape不同,需要使用Broadcasting机制
# 由于b.dim = 1 < a.dim = 2 , b向a看齐
# 如果dim不同需要补齐所以b的shape变换为 [3] => [1,3]
# dim变换后,将shape为1的拉伸: a.shape变换为: [4,1] => [4,3] ; b.shape变换为: [1,3] => [4,3]

在这里插入图片描述
2. 简单实际应用
1. 给你一个tensor,其shape含义为[class,students,scores]
2. 先需要将学生的所有成绩+5分

a = torch.rand(4,32,8) # torch.Size([4, 32, 8])
b = torch.tensor(5) # torch.Size([])
c = a + b
  1. 错误实例
a = torch.rand(4,32,14,14)
b = torch.rand(2,32,14,14)
c = a + b
# RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0
# 这里出错原因: 
# Dim 0 has dim,can NOT insert and expand to same
# Dim 0 has distinct dim, NOT size 1
# NOT broadcasting-able
  1. 多种方式理解
    1. 一下例子我们用[b,c,h,w]
    2. [4,3,32,32] + [32,32] : 对于所有的batch,channels,都叠加相同的picture_map,相当于叠加相同的base,使其平移。
    3. [4,3,32,32] + [3,1,1] :对于RGB来说,增加值
    4. [4,3,32,32] + [1,1,1,1] :对于所有像素点都增加一个值
  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值