Many PyTorch operations support NumPy Broadcasting Semantics.
In short, if a PyTorch operation supports broadcasting, then its Tensor arguments can be automatically expanded to be of equal sizes (without making copies of the data).
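As a quick check of the "without making copies" claim, the sketch below uses `torch.broadcast_tensors` (assuming a reasonably recent PyTorch) and compares strides and storage pointers of the expanded views with the original tensors. A stride of 0 means every index along that dimension reads the same underlying element, so no data is duplicated.

```python
import torch

x = torch.arange(3.0)   # shape (3,)
y = torch.ones(4, 1)    # shape (4, 1)

# broadcast_tensors returns views expanded to the common shape (4, 3)
xb, yb = torch.broadcast_tensors(x, y)

print(xb.shape, xb.stride())  # the new leading dim of x gets stride 0
print(yb.shape, yb.stride())  # the size-1 dim of y gets stride 0

# The views share storage with the originals: no data was copied.
print(xb.data_ptr() == x.data_ptr(), yb.data_ptr() == y.data_ptr())
```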
General semantics
Two tensors are “broadcastable” if the following rules hold:
- Each tensor has at least one dimension.
- When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, or one of them must be 1, or one of them must not exist.
For example:
>>> import torch
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable: x has one dimension of size 0,
# and in the 1st trailing dimension 0 != 2 (and neither size is 1)

# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist

# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3
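The two rules can be written as a small pure-Python check on shape tuples, which makes the "iterate from the trailing dimension" step explicit. This is an illustrative sketch; `is_broadcastable` is a hypothetical helper, not a PyTorch function, and it encodes the rules exactly as stated above.

```python
from itertools import zip_longest


def is_broadcastable(shape_a, shape_b):
    """Check the two broadcasting rules on plain shape tuples (sketch)."""
    # Rule 1: each tensor has at least one dimension.
    if len(shape_a) == 0 or len(shape_b) == 0:
        return False
    # Rule 2: walk the dimensions starting from the trailing one.
    # zip_longest fills a missing dimension with None.
    for a, b in zip_longest(reversed(shape_a), reversed(shape_b)):
        if a is None or b is None:  # one of them does not exist
            continue
        if a != b and a != 1 and b != 1:
            return False
    return True


print(is_broadcastable((5, 7, 3), (5, 7, 3)))     # True
print(is_broadcastable((0,), (2, 2)))             # False
print(is_broadcastable((5, 3, 4, 1), (3, 1, 1)))  # True
print(is_broadcastable((5, 2, 4, 1), (3, 1, 1)))  # False
```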
If two tensors x, y are “broadcastable”, the resulting tensor size is calculated as follows:
- If the number of dimensions of x and y is not equal, prepend 1 to the dimensions of the tensor with fewer dimensions to make them equal length.
- Then, for each dimension, the resulting dimension size is the max of the sizes of x and y along that dimension.
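The two steps can be sketched in pure Python as follows; `broadcast_shape` is a hypothetical helper written for illustration. Rather than taking the max literally, it keeps the size that is not 1. The two agree except when one size is 0: broadcasting a size-0 dimension against a size-1 dimension gives 0 in both NumPy and PyTorch, which this sketch also reproduces.

```python
def broadcast_shape(shape_a, shape_b):
    """Compute the broadcast result shape of two shape tuples (sketch)."""
    if not shape_a or not shape_b:
        raise ValueError("each tensor must have at least one dimension")

    # Step 1: prepend 1s to the shorter shape so both have equal length.
    ndim = max(len(shape_a), len(shape_b))
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)

    # Step 2: combine the sizes dimension by dimension.
    result = []
    for i, (da, db) in enumerate(zip(a, b)):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(
                f"sizes {da} and {db} are incompatible at dimension {i}"
            )
    return tuple(result)


print(broadcast_shape((5, 1, 4, 1), (3, 1, 1)))  # (5, 3, 4, 1)
print(broadcast_shape((1,), (3, 1, 7)))          # (3, 1, 7)
```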
For example:
# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

# but not necessary:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
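When only the result shape is needed, `torch.broadcast_shapes` applies the same rules to shapes without allocating any tensors. The sketch assumes a PyTorch version that provides this function; it reuses the shapes from the examples above.

```python
import torch

# Same shapes as the examples above, but no tensors are allocated.
print(torch.broadcast_shapes((5, 1, 4, 1), (3, 1, 1)))
print(torch.broadcast_shapes((1,), (3, 1, 7)))

# Incompatible shapes raise a RuntimeError, as x + y does.
try:
    torch.broadcast_shapes((5, 2, 4, 1), (3, 1, 1))
except RuntimeError as err:
    print("not broadcastable:", err)
```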
In-place semantics
One complication is that in-place operations do not allow the in-place tensor to change shape as a result of the broadcast.
For example:
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])

# but:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.
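Put differently, an in-place operation on x succeeds only if the broadcast result shape is already equal to x's own shape. The sketch below expresses that condition with a hypothetical helper, `can_add_inplace`, built on `torch.broadcast_shapes` (assumed available in your PyTorch version); the out-of-place form has no such restriction.

```python
import torch


def can_add_inplace(x, y):
    """Hypothetical helper: True if x.add_(y) would keep x's shape."""
    try:
        return torch.broadcast_shapes(x.shape, y.shape) == x.shape
    except RuntimeError:  # x and y are not broadcastable at all
        return False


x = torch.zeros(5, 3, 4, 1)
y = torch.ones(3, 1, 1)
print(can_add_inplace(x, y))    # True: the result shape equals x.shape
x.add_(y)

x2 = torch.zeros(1, 3, 1)
y2 = torch.ones(3, 1, 7)
print(can_add_inplace(x2, y2))  # False: the result would be (3, 3, 7)

# Out-of-place addition simply allocates a new (3, 3, 7) tensor.
print((x2 + y2).shape)
```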
Backwards compatibility
Prior versions of PyTorch allowed certain pointwise functions to execute on tensors with different shapes, as long as the number of elements in each tensor was equal. The pointwise operation would then be carried out by viewing each tensor as 1-dimensional. PyTorch now supports broadcasting and the “1-dimensional” pointwise behavior is considered deprecated and will generate a Python warning in cases where tensors are not broadcastable, but have the same number of elements.
Note that the introduction of broadcasting can cause backwards-incompatible changes in the case where two tensors do not have the same shape, but are broadcastable and have the same number of elements. For example:
>>> torch.add(torch.ones(4,1), torch.randn(4))
would previously produce a Tensor with size torch.Size([4,1]), but now produces a Tensor with size torch.Size([4,4]). To help identify cases in your code where backwards incompatibilities introduced by broadcasting may exist, you may set torch.utils.backcompat.broadcast_warning.enabled to True, which will generate a Python warning in such cases.
For example:
>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.
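Where code relied on the old "1-dimensional" pointwise result, one option is to reshape an operand explicitly so both shapes match before the operation. The sketch below contrasts the broadcasting result with an explicit `view_as`; this reshape is one illustrative way to restore the old semantics, not a prescribed migration path.

```python
import torch

a = torch.ones(4, 1)
b = torch.arange(4.0)  # shape (4,): same number of elements as a

# Broadcasting: (4, 1) and (4,) combine to (4, 4).
broadcast = torch.add(a, b)
print(broadcast.shape)

# Old behavior made explicit: view b with a's shape, so the addition
# is element-wise and the result keeps shape (4, 1).
elementwise = torch.add(a, b.view_as(a))
print(elementwise.shape)

# Equivalent to flattening, adding, and restoring a's shape.
same = (a.view(-1) + b).view_as(a)
print(torch.equal(elementwise, same))
```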