Broadcasting pseudocode
Inputs: array A with m dimensions; array B with n dimensions
p = max(m, n)
if m < p:
    left-pad A's shape with 1s until it also has p dimensions
else if n < p:
    left-pad B's shape with 1s until it also has p dimensions
result_dims = new list with p elements
for i in p-1 ... 0:
    A_dim_i = A.shape[i]
    B_dim_i = B.shape[i]
    if A_dim_i != 1 and B_dim_i != 1 and A_dim_i != B_dim_i:
        raise ValueError("could not broadcast")
    else:
        result_dims[i] = max(A_dim_i, B_dim_i)
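The pseudocode can be turned into a small runnable function. The following is a minimal sketch in pure Python that works on shape tuples only, not on real arrays. The name `broadcast_shape` is hypothetical and is not part of NumPy or PyTorch. One deliberate difference from the pseudocode: instead of `max`, it picks whichever size is not 1, which gives the same result for positive sizes.

```python
def broadcast_shape(a_shape, b_shape):
    """Return the broadcast result shape, or raise ValueError if incompatible."""
    p = max(len(a_shape), len(b_shape))
    # Left-pad the shorter shape with 1s so both have p dimensions.
    a = (1,) * (p - len(a_shape)) + tuple(a_shape)
    b = (1,) * (p - len(b_shape)) + tuple(b_shape)

    result = [0] * p
    for i in range(p - 1, -1, -1):
        if a[i] != 1 and b[i] != 1 and a[i] != b[i]:
            raise ValueError(
                f"could not broadcast {tuple(a_shape)} with {tuple(b_shape)} "
                f"at dimension {i}"
            )
        # Take the size that is not 1; for positive sizes this equals max(a[i], b[i]).
        result[i] = b[i] if a[i] == 1 else a[i]
    return tuple(result)


print(broadcast_shape((4, 1, 2), (6, 1)))  # (4, 6, 2)
```

Calling it with incompatible shapes such as `(4, 3)` and `(4,)` raises `ValueError`, because the trailing dimensions 3 and 4 differ and neither is 1.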
Example
# PyTorch operations support NumPy broadcasting semantics.
import torch

x = torch.ones(4, 1, 2)
y = torch.ones(6, 1)
print((x + y).size())
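# The shape with fewer dimensions is left-padded with 1s
# until it has as many dimensions as the longer shape:
#   x: 4 1 2
#   y: 1 6 1   (padded from 6 1)
# Compare the dimensions from last to first:
#   x:      2 1 4
#   y:      1 6 1
#   result: 2 6 4
# Reading the result back in the original order gives the final shape:
#   4 6 2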
torch.Size([4, 6, 2])
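The back-to-front comparison traced in the comments above can also be reproduced step by step without PyTorch. The following is a rough sketch using `itertools.zip_longest` on the reversed shapes. The helper name `trace_broadcast` is an assumption made for illustration and is not a PyTorch or NumPy function.

```python
from itertools import zip_longest


def trace_broadcast(a_shape, b_shape):
    """Pair dimensions from last to first and return the per-dimension rows."""
    rows = []
    # Reversing both shapes aligns their trailing dimensions; fillvalue=1
    # plays the role of left-padding the shorter shape with 1s.
    for a, b in zip_longest(reversed(a_shape), reversed(b_shape), fillvalue=1):
        if a != 1 and b != 1 and a != b:
            raise ValueError(f"could not broadcast: {a} vs {b}")
        rows.append((a, b, b if a == 1 else a))
    return rows


rows = trace_broadcast((4, 1, 2), (6, 1))
for a, b, out in rows:
    print(a, b, "->", out)
# Reverse the per-dimension results to get the shape in its usual order.
print(tuple(out for _, _, out in reversed(rows)))
```

Each printed row corresponds to one column of the hand trace: the size from `x`, the size from `y` (or the padded 1), and the resulting size for that dimension.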