一般语义(做加减法,但是不影响源张量)
如果遵守以下规则,则两个张量是“可播放的”:
- 每个张量至少有一个维度。
- 两个张量维度不一致的时候,维度小的那个张量直接在最前面自动加一个维度,其余的维度必须保持相等或者其中有一个1
- 维度即便相等,在做加法减法的时候也需要保持所有维度要么相等要么其中存在一个1
- 最后生成的维度就是两个张量各自最大的那个维度
例如:
x = torch.empty(5, 3, 1, 1)
y = torch.empty(3, 3, 1)
print((x+y).size())
## torch.Size([5, 3, 3, 1])
首先y比x小一个维度,具体的处理过程如下:
1. y自动增加一个维度(1,3,3,1)
2. 比较维度并且选择更大的维度
5 ---- 1 ----> 5
3 ---- 1 ----> 3
1 ---- 3 ----> 3
1 ---- 1 ----> 1
3. 输出维度为5,3,3,1
x = torch.empty(5, 3, 4, 1)
y = torch.empty(3, 2, 1)
print((x+y).size())
#The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 2
首先y比x小一个维度,具体的处理过程如下:
1. y自动增加一个维度(1,3,2,1)
2. 比较维度并且选择更大的维度
5 ---- 1 ----> 5
3 ---- 3 ----> 3
4 ---- 2 ----> 不满足第二条约定,需要其他维度都相等或者存在一个1
1 ---- 1 ----> 1
3. 输出错误
就地语义(对源张量做处理)
不允许每个张量由于广播而改变形状。
x = torch.empty(1,3,1)
y = torch.empty(3,1,7)
print(x.add_(y).size())
# output with shape [1, 3, 1] doesn't match the broadcast shape [3, 3, 7]