BROADCASTING SEMANTICS
Broadcasting in PyTorch:
Reposted from https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics
General semantics
Two tensors are “broadcastable” if the following rules hold:
- Each tensor has at least one dimension.
- When iterating over the dimension sizes, starting at the trailing dimension, the two sizes must either be equal, or one of them must be 1, or one of them must not exist.
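The two rules can be checked on shapes alone, without allocating any tensors. The sketch below works on plain Python tuples. The function name `is_broadcastable` is hypothetical (not a PyTorch API), and it encodes the two rules literally as stated above.

```python
from typing import Sequence


def is_broadcastable(shape_a: Sequence[int], shape_b: Sequence[int]) -> bool:
    """Check the two broadcastability rules on plain shape tuples."""
    # Rule 1: each tensor must have at least one dimension.
    if len(shape_a) == 0 or len(shape_b) == 0:
        return False
    # Rule 2: walk the sizes starting from the trailing dimension.
    # zip() stops at the shorter shape, so the leading dimensions of the
    # longer shape are the ones that "do not exist" in the other tensor.
    for size_a, size_b in zip(reversed(shape_a), reversed(shape_b)):
        if size_a != size_b and size_a != 1 and size_b != 1:
            return False
    return True


if __name__ == "__main__":
    print(is_broadcastable((5, 1, 4, 1), (3, 1, 1)))  # True
    print(is_broadcastable((1,), (3, 1, 7)))          # True
    print(is_broadcastable((5, 2, 4, 1), (3, 1, 1)))  # False: 2 vs 3
    print(is_broadcastable((), (2, 2)))               # False: no dimensions
```

In the third call, the trailing dimensions compare as (1, 1), (4, 1) and then (2, 3). The last pair is neither equal nor contains a 1, so the shapes are rejected.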
If two tensors x, y are “broadcastable”, the resulting tensor size is calculated as follows:
- If the number of dimensions of x and y is not equal, prepend dimensions of size 1 to the shape of the tensor with fewer dimensions until both shapes have the same length.
- Then, for each dimension, the resulting dimension size is the max of the sizes of x and y along that dimension.
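The two steps translate directly into a shape computation. This is a minimal sketch on plain tuples. The name `broadcast_shape` and its `ValueError` message are hypothetical; they only imitate the wording of PyTorch's error. The dimension index it reports refers to the padded shape. That is why, in the failing example, y = (3, 1, 1) is padded to (1, 3, 1, 1) and the mismatch shows up at dimension 1.

```python
from typing import Sequence, Tuple


def broadcast_shape(shape_a: Sequence[int], shape_b: Sequence[int]) -> Tuple[int, ...]:
    """Compute the broadcast result shape of two shapes, or raise ValueError."""
    ndim = max(len(shape_a), len(shape_b))
    # Step 1: prepend 1s to the shorter shape so both have ndim entries.
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = []
    for dim, (size_a, size_b) in enumerate(zip(a, b)):
        if size_a != size_b and size_a != 1 and size_b != 1:
            raise ValueError(
                f"size {size_a} does not match size {size_b} "
                f"at non-singleton dimension {dim}"
            )
        # Step 2: take the non-1 size. For positive sizes this equals
        # max(size_a, size_b); it also keeps a zero-sized dimension at 0.
        result.append(size_b if size_a == 1 else size_a)
    return tuple(result)


if __name__ == "__main__":
    print(broadcast_shape((5, 1, 4, 1), (3, 1, 1)))  # (5, 3, 4, 1)
    print(broadcast_shape((1,), (3, 1, 7)))          # (3, 1, 7)
```

Choosing the non-1 size instead of calling `max()` gives the same answer whenever all sizes are positive. The only case where the two differ is a size-0 dimension paired with a size-1 dimension.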
For example:
>>> import torch
# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])
# but not necessary:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])
# not broadcastable: in the 3rd trailing dimension, 2 != 3
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1