如果两个数组的后缘维度,即从末尾开始算起的维度的轴长度相符,或其中的一方的长度为1,则认为它们是广播兼容的。
1.后缘维度的轴长是相同的。
eg1:
x=torch.arange(6).reshape(2,3)
y=torch.tensor([1,2,3])
y.shape,x.shape,x,y,x+y
-------------------------------
(torch.Size([3]),
torch.Size([2, 3]),
tensor([[0, 1, 2],
[3, 4, 5]]),
tensor([1, 2, 3]),
tensor([[1, 3, 5],
[4, 6, 8]]))
#缺省也是没问题的
eg2:
x=torch.arange(6).reshape(2,3)
y=torch.tensor([[1,2,3]])
y.shape,x.shape,x,y,x+y
-------------------------------
(torch.Size([1, 3]),
torch.Size([2, 3]),
tensor([[0, 1, 2],
[3, 4, 5]]),
tensor([[1, 2, 3]]),
tensor([[1, 3, 5],
[4, 6, 8]]))
#其实可以看成第二种情况
eg3:
x=torch.arange(12).reshape(2,2,3)
y=torch.tensor([1,2,3])
y.shape,x.shape,x,y,x+y
-------------------------------
(torch.Size([3]),
torch.Size([2, 2, 3]),
tensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]]),
tensor([1, 2, 3]),
tensor([[[ 1, 3, 5],
[ 4, 6, 8]],
[[ 7, 9, 11],
[10, 12, 14]]]))
eg4:
x=torch.arange(12).reshape(2,2,3)
y=torch.tensor([[1,2,3],[1,2,3]])
y.shape,x.shape,x,y,x+y
-------------------------------
(torch.Size([2, 3]),
torch.Size([2, 2, 3]),
tensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]]),
tensor([[1, 2, 3],
[1, 2, 3]]),
tensor([[[ 1, 3, 5],
[ 4, 6, 8]],
[[ 7, 9, 11],
[10, 12, 14]]]))
2.两个数组维数相同,有一方的长度为1。
eg1:
x=torch.arange(12).reshape(4,3)
y=torch.arange(4).reshape(4,1)
x.shape,y.shape,x,y,x+y
-------------------------------
(torch.Size([4, 3]),
torch.Size([4, 1]),
tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]),
tensor([[0],
[1],
[2],
[3]]),
tensor([[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10],
[12, 13, 14]]))
eg2:
x=torch.arange(18).reshape(3,1,6)
y=torch.arange(15).reshape(3,5,1)
x.shape,y.shape,x,y,x+y
-------------------------------
(torch.Size([3, 1, 6]),
torch.Size([3, 5, 1]),
tensor([[[ 0, 1, 2, 3, 4, 5]],
[[ 6, 7, 8, 9, 10, 11]],
[[12, 13, 14, 15, 16, 17]]]),
tensor([[[ 0],
[ 1],
[ 2],
[ 3],
[ 4]],
[[ 5],
[ 6],
[ 7],
[ 8],
[ 9]],
[[10],
[11],
[12],
[13],
[14]]]),
tensor([[[ 0, 1, 2, 3, 4, 5],
[ 1, 2, 3, 4, 5, 6],
[ 2, 3, 4, 5, 6, 7],
[ 3, 4, 5, 6, 7, 8],
[ 4, 5, 6, 7, 8, 9]],
[[11, 12, 13, 14, 15, 16],
[12, 13, 14, 15, 16, 17],
[13, 14, 15, 16, 17, 18],
[14, 15, 16, 17, 18, 19],
[15, 16, 17, 18, 19, 20]],
[[22, 23, 24, 25, 26, 27],
[23, 24, 25, 26, 27, 28],
[24, 25, 26, 27, 28, 29],
[25, 26, 27, 28, 29, 30],
[26, 27, 28, 29, 30, 31]]]))
注意.从以上两种情况可以得出torch.Size([3, 5, 6]),
torch.Size([1, 6]),也是没问题的。
x=torch.arange(90).reshape(3,5,6)
y=torch.arange(6).reshape(1,6)
x.shape,y.shape,x,y,x+y
-------------------------------
(torch.Size([3, 5, 6]),
torch.Size([1, 6]),
tensor([[[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29]],
[[30, 31, 32, 33, 34, 35],
[36, 37, 38, 39, 40, 41],
[42, 43, 44, 45, 46, 47],
[48, 49, 50, 51, 52, 53],
[54, 55, 56, 57, 58, 59]],
[[60, 61, 62, 63, 64, 65],
[66, 67, 68, 69, 70, 71],
[72, 73, 74, 75, 76, 77],
[78, 79, 80, 81, 82, 83],
[84, 85, 86, 87, 88, 89]]]),
tensor([[0, 1, 2, 3, 4, 5]]),
tensor([[[ 0, 2, 4, 6, 8, 10],
[ 6, 8, 10, 12, 14, 16],
[12, 14, 16, 18, 20, 22],
[18, 20, 22, 24, 26, 28],
[24, 26, 28, 30, 32, 34]],
[[30, 32, 34, 36, 38, 40],
[36, 38, 40, 42, 44, 46],
[42, 44, 46, 48, 50, 52],
[48, 50, 52, 54, 56, 58],
[54, 56, 58, 60, 62, 64]],
[[60, 62, 64, 66, 68, 70],
[66, 68, 70, 72, 74, 76],
[72, 74, 76, 78, 80, 82],
[78, 80, 82, 84, 86, 88],
[84, 86, 88, 90, 92, 94]]]))