Broadcasting in PyTorch

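This article looks at when PyTorch tensors can be broadcast: the dimension-matching requirements and how broadcasting is carried out. Several examples show how broadcasting expands tensor dimensions so that an operation can proceed, and how the size of the resulting tensor is computed. Further examples walk through element-wise operations, such as addition, on tensors of different shapes to show how broadcasting applies in practice.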

Broadcasting conditions

Two tensors are broadcastable only if both of the following conditions hold:

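  1. Each tensor has at least one dimension.
  2. When the dimension sizes are compared from the trailing (rightmost) dimension backwards, the two sizes must be equal, or one of them must be 1, or one of them must not exist (that tensor has run out of dimensions).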
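
The second condition can be checked on plain shape tuples, without creating any tensors. The following is a minimal sketch under that assumption: `is_broadcastable` is a hypothetical helper name, not a PyTorch function, and it checks only the trailing-dimension rule.

```python
from itertools import zip_longest

def is_broadcastable(shape_x, shape_y):
    """Hypothetical helper (not part of PyTorch): check the trailing-dimension rule."""
    # Walk both shapes from the last dimension to the first.
    # fillvalue=1 stands in for a dimension that the shorter shape does not have.
    for dx, dy in zip_longest(reversed(shape_x), reversed(shape_y), fillvalue=1):
        if dx != dy and dx != 1 and dy != 1:
            return False
    return True

print(is_broadcastable((5, 3, 4, 1), (3, 1, 1)))  # True
print(is_broadcastable((3, 2, 4, 1), (3, 1, 1)))  # False: 2 vs 3 in the 3rd trailing dimension
```

Treating a missing dimension as size 1 is a design choice of this sketch; it gives the same answer as the "dimension does not exist" clause of the rule.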

Examples:

>>> import torch
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least one dimension (it is empty: its only dimension has size 0)

# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist

>>> x=torch.empty(3,2,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3
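
The four shape pairs above can also be checked by simply attempting the addition and catching the error PyTorch raises on a mismatch. This is only a sketch: `can_add` is a hypothetical helper name, and `torch.zeros` is used instead of `torch.empty` so the values are defined.

```python
import torch

def can_add(x, y):
    """Hypothetical helper: True if x + y succeeds under broadcasting."""
    try:
        x + y
        return True
    except RuntimeError:
        return False

# The shape pairs from the examples above, in the same order.
pairs = [
    ((5, 7, 3), (5, 7, 3)),
    ((0,), (2, 2)),
    ((5, 3, 4, 1), (3, 1, 1)),
    ((3, 2, 4, 1), (3, 1, 1)),
]
results = [can_add(torch.zeros(a), torch.zeros(b)) for a, b in pairs]
for (a, b), ok in zip(pairs, results):
    print(a, b, "broadcastable" if ok else "not broadcastable")
```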

———————————————————————————

Broadcasting rules

If two tensors x and y are broadcastable, the size of the resulting tensor is computed as follows:

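  1. If x and y have different numbers of dimensions, prepend dimensions of size 1 to the tensor with fewer dimensions until both have the same number of dimensions.
  2. For each dimension, the resulting size is the maximum of the sizes of x and y along that dimension. (If the two sizes differ, one of them must be 1.)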
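
The two steps can be written out on shape tuples as a small sketch. `broadcast_result_shape` is a hypothetical name used here for illustration, not a PyTorch API, and raising `ValueError` on a mismatch is a choice of this sketch.

```python
def broadcast_result_shape(shape_x, shape_y):
    """Hypothetical helper (not a PyTorch API): result shape per the two steps above."""
    ndim = max(len(shape_x), len(shape_y))
    # Step 1: left-pad the shorter shape with size-1 dimensions.
    px = (1,) * (ndim - len(shape_x)) + tuple(shape_x)
    py = (1,) * (ndim - len(shape_y)) + tuple(shape_y)
    result = []
    for dx, dy in zip(px, py):
        if dx != dy and dx != 1 and dy != 1:
            raise ValueError(f"sizes {dx} and {dy} cannot be broadcast")
        # Step 2: keep the size that is not 1. For non-empty tensors this is
        # the same as taking the maximum of the two sizes.
        result.append(dy if dx == 1 else dx)
    return tuple(result)

print(broadcast_result_shape((5, 1, 4, 1), (3, 1, 1)))  # (5, 3, 4, 1)
print(broadcast_result_shape((1,), (3, 1, 7)))          # (3, 1, 7)
```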

Example 1:

>>> import torch
# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

# but not necessary:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
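
To make the two steps visible on real tensors, the sketch below performs them by hand for the first pair of shapes and compares the result with the implicit broadcast. It assumes a PyTorch version that provides `torch.broadcast_shapes` (1.8 or later); the `arange` values are arbitrary test data, not from the original example.

```python
import torch

x = torch.arange(20).reshape(5, 1, 4, 1)
y = torch.arange(3).reshape(3, 1, 1)

# Step 1: give y a leading size-1 dimension so both tensors are 4-D.
y4 = y.unsqueeze(0)                      # shape (1, 3, 1, 1)
# Step 2: expand every size-1 dimension to the larger size (a view, no copy).
target = torch.broadcast_shapes(x.shape, y.shape)
xe, ye = x.expand(target), y4.expand(target)

print(target)                            # torch.Size([5, 3, 4, 1])
print(torch.equal(x + y, xe + ye))       # True
```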

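The examples above show the rules in terms of shapes. The examples below show how broadcasting works at the level of individual elements: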

Example 2:

import torch

x = torch.arange(0,24).reshape(2,4,3)
print(x)
y = torch.arange(0,3).reshape(1,3)
print(y)
z = x + y
print(z)

Output:

tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]],
         
        [[12, 13, 14],
         [15, 16, 17],
         [18, 19, 20],
         [21, 22, 23]]])
         
tensor([[0, 1, 2]])

tensor([[[ 0,  2,  4],
         [ 3,  5,  7],
         [ 6,  8, 10],
         [ 9, 11, 13]],
         
        [[12, 14, 16],
         [15, 17, 19],
         [18, 20, 22],
         [21, 23, 25]]])
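
One way to read this result is that the single row of y is reused for every pair of leading indices, so z[i, j, k] = x[i, j, k] + y[0, k]. The loop below is a deliberately slow sketch that rebuilds z element by element to confirm that reading; `manual` is a name introduced here for illustration.

```python
import torch

x = torch.arange(0, 24).reshape(2, 4, 3)
y = torch.arange(0, 3).reshape(1, 3)
z = x + y

# Rebuild the result one element at a time: y's only row serves every (i, j).
manual = torch.empty_like(x)
for i in range(2):
    for j in range(4):
        for k in range(3):
            manual[i, j, k] = x[i, j, k] + y[0, k]

print(torch.equal(z, manual))  # True
```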

Example 3:

x = torch.arange(0,24).reshape(2,4,3)
print(x)
y = torch.arange(0,4).reshape(4,1)
print(y)
z = x + y
print(z)

Output:

tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]],
         
        [[12, 13, 14],
         [15, 16, 17],
         [18, 19, 20],
         [21, 22, 23]]])
         
tensor([[0],
        [1],
        [2],
        [3]])
        
tensor([[[ 0,  1,  2],
         [ 4,  5,  6],
         [ 8,  9, 10],
         [12, 13, 14]],
         
        [[12, 13, 14],
         [16, 17, 18],
         [20, 21, 22],
         [24, 25, 26]]])
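
The trailing size-1 dimension is what lets the four values line up with x's middle dimension. As a sketch of the contrast, the same four values stored with shape (4,) are matched against x's last dimension (size 3) instead and fail; the names `col` and `flat` are introduced here for illustration.

```python
import torch

x = torch.arange(0, 24).reshape(2, 4, 3)
col = torch.arange(0, 4).reshape(4, 1)   # shape (4, 1): 4 lines up with x's middle dimension
flat = torch.arange(0, 4)                # shape (4,): 4 is compared with x's last dimension, 3

print((x + col).shape)                   # torch.Size([2, 4, 3])
try:
    x + flat
    flat_ok = True
except RuntimeError as err:
    flat_ok = False
    print("RuntimeError:", err)
```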

Example 4:

x = torch.arange(0,48).reshape(2,4,3,2)
print(x)
y = torch.arange(0,3).reshape(3,1)
print(y)
z = x + y
print(z)

Output:

tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]],
          
         [[ 6,  7],
          [ 8,  9],
          [10, 11]],
          
         [[12, 13],
          [14, 15],
          [16, 17]],
          
         [[18, 19],
          [20, 21],
          [22, 23]]],
          
        [[[24, 25],
          [26, 27],
          [28, 29]],
          
         [[30, 31],
          [32, 33],
          [34, 35]],
          
         [[36, 37],
          [38, 39],
          [40, 41]],
          
         [[42, 43],
          [44, 45],
          [46, 47]]]])
          
tensor([[0],
        [1],
        [2]])
        
tensor([[[[ 0,  1],
          [ 3,  4],
          [ 6,  7]],
          
         [[ 6,  7],
          [ 9, 10],
          [12, 13]],
          
         [[12, 13],
          [15, 16],
          [18, 19]],
          
         [[18, 19],
          [21, 22],
          [24, 25]]],
          
        [[[24, 25],
          [27, 28],
          [30, 31]],
          
         [[30, 31],
          [33, 34],
          [36, 37]],
          
         [[36, 37],
          [39, 40],
          [42, 43]],
          
         [[42, 43],
          [45, 46],
          [48, 49]]]])
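
Here y has two fewer dimensions than x, so by the first rule it behaves like a tensor of shape (1, 1, 3, 1). The sketch below writes those leading dimensions out explicitly and checks that nothing changes; `y_explicit` and `diff` are names introduced for illustration.

```python
import torch

x = torch.arange(0, 48).reshape(2, 4, 3, 2)
y = torch.arange(0, 3).reshape(3, 1)

# Writing the two implicit leading dimensions out by hand changes nothing.
y_explicit = y.reshape(1, 1, 3, 1)
same = torch.equal(x + y, x + y_explicit)
print(same)                      # True

# Every (3, 2) block of x receives the same per-row offsets 0, 1, 2.
diff = (x + y) - x
print(diff[0, 0].tolist())       # [[0, 0], [1, 1], [2, 2]]
print(diff[1, 3].tolist())       # [[0, 0], [1, 1], [2, 2]]
```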

Example 5:

x = torch.arange(0,48).reshape(2,4,3,2)
print(x)
y = torch.arange(0,8).reshape(4,1,2)
print(y)
z = x + y
print(z)

Output:

tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]],
          
         [[ 6,  7],
          [ 8,  9],
          [10, 11]],
          
         [[12, 13],
          [14, 15],
          [16, 17]],
          
         [[18, 19],
          [20, 21],
          [22, 23]]],
          
        [[[24, 25],
          [26, 27],
          [28, 29]],
          
         [[30, 31],
          [32, 33],
          [34, 35]],
          
         [[36, 37],
          [38, 39],
          [40, 41]],
          
         [[42, 43],
          [44, 45],
          [46, 47]]]])
          
tensor([[[0, 1]],

        [[2, 3]],
        
        [[4, 5]],
                        
        [[6, 7]]])
        
tensor([[[[ 0,  2],
          [ 2,  4],
          [ 4,  6]],
          
         [[ 8, 10],
          [10, 12],
          [12, 14]],
          
         [[16, 18],
          [18, 20],
          [20, 22]],
          
         [[24, 26],
          [26, 28],
          [28, 30]]],
          
        [[[24, 26],
          [26, 28],
          [28, 30]],
          
         [[32, 34],
          [34, 36],
          [36, 38]],
          
         [[40, 42],
          [42, 44],
          [44, 46]],
          
         [[48, 50],
          [50, 52],
          [52, 54]]]])
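
Broadcasting does not need to copy y to the full shape. As a hedged illustration, `torch.broadcast_tensors` returns the expanded views that the addition effectively works on, and the broadcast dimensions of the view of y have stride 0, meaning the same stored values are read repeatedly.

```python
import torch

x = torch.arange(0, 48).reshape(2, 4, 3, 2)
y = torch.arange(0, 8).reshape(4, 1, 2)

xb, yb = torch.broadcast_tensors(x, y)
print(yb.shape)     # torch.Size([2, 4, 3, 2])
print(yb.stride())  # stride 0 in the dimensions that were broadcast

# The expanded view shares storage with y: 8 stored values back 48 logical elements.
shares_storage = yb.data_ptr() == y.data_ptr()
print(shares_storage, y.numel(), yb.numel())
```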

Example 6:

x = torch.arange(0,48).reshape(2,4,3,2)
print(x)
y = torch.arange(2,4).reshape(1,1,2)
print(y)
z = x + y
print(z)

Output:

tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]],
          
         [[ 6,  7],
          [ 8,  9],
          [10, 11]],
          
         [[12, 13],
          [14, 15],
          [16, 17]],
          
         [[18, 19],         
          [20, 21],
          [22, 23]]],
          
        [[[24, 25],
          [26, 27],
          [28, 29]],
          
         [[30, 31],
          [32, 33],
          [34, 35]],
          
         [[36, 37],
          [38, 39],
          [40, 41]],
          
         [[42, 43],
          [44, 45],
          [46, 47]]]])
          
tensor([[[2, 3]]])

tensor([[[[ 2,  4],
          [ 4,  6],
          [ 6,  8]],
          
         [[ 8, 10],
          [10, 12],
          [12, 14]],
          
         [[14, 16],
          [16, 18],
          [18, 20]],
          
         [[20, 22],
          [22, 24],
          [24, 26]]],
          
        [[[26, 28],
          [28, 30],
          [30, 32]],
          
         [[32, 34],
          [34, 36],
          [36, 38]],
          
         [[38, 40],
          [40, 42],
          [42, 44]],
          
         [[44, 46],
          [46, 48],
          [48, 50]]]])
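
Leading dimensions of size 1 carry no extra information for broadcasting. A short sketch, with `y3` and `y1` as illustrative names: the (1, 1, 2) tensor from this example gives the same result as a plain 1-D tensor holding the same two values.

```python
import torch

x = torch.arange(0, 48).reshape(2, 4, 3, 2)
y3 = torch.arange(2, 4).reshape(1, 1, 2)   # shape (1, 1, 2), as in the example
y1 = torch.arange(2, 4)                    # shape (2,)

print(torch.equal(x + y3, x + y1))         # True
print((x + y1)[0, 0].tolist())             # [[2, 4], [4, 6], [6, 8]]
```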

Example 7:

x = torch.arange(0,48).reshape(2,4,3,2)
print(x)
y = torch.arange(2,8).reshape(2,1,3,1)
print(y)
z = x + y
print(z)

Output:

tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]],
          
         [[ 6,  7],
          [ 8,  9],
          [10, 11]],
          
         [[12, 13],
          [14, 15],
          [16, 17]],
          
         [[18, 19],
          [20, 21],
          [22, 23]]],
          
        [[[24, 25],
          [26, 27],
          [28, 29]],
          
         [[30, 31],
          [32, 33],
          [34, 35]],
          
         [[36, 37],
          [38, 39],
          [40, 41]],
          
         [[42, 43],
          [44, 45],
          [46, 47]]]])
          
tensor([[[[2],
          [3],
          [4]]],
          
        [[[5],
          [6],
          [7]]]])
          
tensor([[[[ 2,  3],
          [ 5,  6],
          [ 8,  9]],
          
         [[ 8,  9],
          [11, 12],
          [14, 15]],
          
         [[14, 15],
          [17, 18],
          [20, 21]],
          
         [[20, 21],
          [23, 24],
          [26, 27]]],
          
        [[[29, 30],
          [32, 33],
          [35, 36]],
          
         [[35, 36],
          [38, 39],
          [41, 42]],
          
         [[41, 42],
          [44, 45],
          [47, 48]],
          
         [[47, 48],
          [50, 51],
          [53, 54]]]])
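
In this example y already has four dimensions, and only its size-1 dimensions (dim 1 and dim 3) are stretched. The sketch below materializes that stretching with `Tensor.repeat`, which copies data, to show that it gives the same sum; `y_full` is a name introduced for illustration.

```python
import torch

x = torch.arange(0, 48).reshape(2, 4, 3, 2)
y = torch.arange(2, 8).reshape(2, 1, 3, 1)

# repeat(1, 4, 1, 2) copies y 4 times along dim 1 and twice along dim 3.
y_full = y.repeat(1, 4, 1, 2)
print(y_full.shape)                      # torch.Size([2, 4, 3, 2])
print(torch.equal(x + y, x + y_full))    # True
```

Broadcasting gives the same values without building `y_full` in memory first.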

Example 8:

x = torch.arange(0,24).reshape(2,4,3,1)
print(x)
y = torch.arange(0,12).reshape(2,1,3,2)
print(y)
z = x + y
print(z)

Output:

tensor([[[[ 0],
          [ 1],
          [ 2]],
          
         [[ 3],
          [ 4],
          [ 5]],
          
         [[ 6],
          [ 7],
          [ 8]],
          
         [[ 9],
          [10],
          [11]]],
          
        [[[12],
          [13],
          [14]],
          
         [[15],
          [16],
          [17]],
          
         [[18],
          [19],
          [20]],
          
         [[21],
          [22],
          [23]]]])
          
tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]]],
          
        [[[ 6,  7],
          [ 8,  9],
          [10, 11]]]])
          
tensor([[[[ 0,  1],
          [ 3,  4],
          [ 6,  7]],

         [[ 3,  4],
          [ 6,  7],
          [ 9, 10]],
          
         [[ 6,  7],         
          [ 9, 10],
          [12, 13]],
          
         [[ 9, 10],
          [12, 13],
          [15, 16]]],
          
        [[[18, 19],
          [21, 22],
          [24, 25]],
          
         [[21, 22],
          [24, 25],
          [27, 28]],
          
         [[24, 25],
          [27, 28],
          [30, 31]],
          
         [[27, 28],
          [30, 31],
          [33, 34]]]])
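
Unlike the earlier examples, both operands are stretched here: x along its last dimension and y along dim 1, so the result is larger than either input. A small sketch of that, using `torch.broadcast_tensors` to expose the expanded views:

```python
import torch

x = torch.arange(0, 24).reshape(2, 4, 3, 1)
y = torch.arange(0, 12).reshape(2, 1, 3, 2)

xb, yb = torch.broadcast_tensors(x, y)
z = x + y
print(xb.shape, yb.shape, z.shape)      # all torch.Size([2, 4, 3, 2])
print(x.numel(), y.numel(), z.numel())  # 24 12 48

# Element-wise: z[a, b, c, d] = x[a, b, c, 0] + y[a, 0, c, d]
print(z[1, 3, 2, 1].item(), (x[1, 3, 2, 0] + y[1, 0, 2, 1]).item())  # 34 34
```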