在 PyTorch 中存在广播(Broadcasting),广播是一种机制,用于自动扩展较小的张量以匹配较大张量的形状,从而使得它们能够进行元素级操作(如加法、减法、乘法等)。广播并不改变张量的实际数据,而是通过虚拟扩展来简化操作。
广播机制的规则
-
如果两个张量的维度数量不同,则将较小的那个张量的形状前面补 1,直到两个张量的维度数量相同。
-
如果两个张量在某个维度上的大小不一致,但其中一个张量在该维度上的大小是 1,则可以在该维度上进行广播。
-
如果两个张量在任何维度上的大小既不相等也不为 1,则无法进行广播。
-
广播后的张量形状是每个维度上大小的最大值。
import torch
# 示例 1: 形状不同的张量相加
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([1, 2, 3])
# b 会被广播成 [[1, 2, 3], [1, 2, 3]]
result = a + b
print(result)
# 输出:
# tensor([[ 2, 4, 6],
# [ 5, 7, 9]])
# 示例 2: 形状不同的张量相乘
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = torch.tensor([1, 2])
# b 会被广播成 [[1, 2], [1, 2], [1, 2]]
result = a * b
print(result)
# 输出:
# tensor([[ 1, 4],
# [ 3, 8],
# [ 5, 12]])
# 示例 3: 形状不同的张量相加
a = torch.tensor(