提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
文章目录
PyTorch 的广播机制可以使不同形状的张量能够进行逐元素操作,如加减乘除,前提是这些张量的形状是兼容的。广播机制遵循一套特定的规则来扩展较小的张量以匹配较大张量的形状。
广播机制规则
- 从右向左对齐形状:比较两个张量的形状时,从最右边的维度开始对齐。
- 维度匹配:对应位置的两个维度要么相等,要么其中一个为 1或空,否则运算会失败。
- 维度扩展:当一个维度为 1 时,可以沿该维度复制,使其与另一个张量的相应维度匹配。
广播机制示例
假设有两个张量 A
和 B
,以下是一些示例来解释广播规则:
示例 1: 形状为(4, 1)和()可以广播
- 张量 A 形状为
(4, 1)
- 张量 B 形状为
()
对齐后的形状均为:(4, 1)
在这种情况下,B 是一个标量,它会被扩充出两个维度。广播过程详述为:
- 按照规则从右开始对齐维度数字,A的的第一个维度数字为1,B第一个维度数字为空,此时满足扩充条件“其中一方为空”,B扩充为形状为(1),数据为[1]的张量;
- 往左继续对齐维度数字,A的的第二个维度数字为4,B第二个维度数字为空,此时满足扩充条件“其中一方为空”,B扩充为形状为(4,1),数据为
[
[1],
[1],
[1],
[1],
]
的张量;
- 往左继续对齐,发现AB都没有再高的维度。最终B被广播成
(4, 1)
的形状。
import torch
A = torch.tensor([[1], [2], [3], [4]]) # 形状 (4, 1)
B = torch.tensor(1) # 形状 (4,)
result = A + B
print(result)
# 输出:
# tensor([[2],
# [3],
# [4],
# [5]])
示例 2: 形状为(4, 1)和(4,)可以广播
- 张量 A 形状为
(4, 1)
- 张量 B 形状为
(4,)
对齐后的形状均为:(4, 1)
在这种情况下,B 的第二个维度可以扩展为1,因此可以广播成 (4, 1)
。结果张量的形状为 (4, 1)
。
import torch
A = torch.tensor([[1], [2], [3], [4]]) # 形状 (4, 1)
B = torch.tensor([5, 6, 7, 8]) # 形状 (4,)
result = A * B
print(result)
# 输出:
# tensor([[ 5, 6, 7, 8],
# [10, 12, 14, 16],
# [15, 18, 21, 24],
# [20, 24, 28, 32]])
示例 3: 形状为(1, 3, 1)和(3, 1)可以广播
- 张量 A 形状为
(1, 3, 1)
- 张量 B 形状为
(3, 1)
对齐后的形状均为:(1, 3, 1)
在这种情况下,B 的形状被视为 (1, 3, 1)
,结果张量的形状为 (1, 3, 1)
。
A = torch.tensor([[[1], [2], [3]]]) # 形状 (1, 3, 1)
B = torch.tensor([[4], [5], [6]]) # 形状 (3, 1)
result = A * B
print(result)
# 输出:
# tensor([[[ 4],
# [ 5],
# [ 6]],
# [[ 8],
# [10],
# [12]],
# [[12],
# [15],
# [18]]])
示例 4: 形状为(2, 1, 3)和(1, 4, 1)可以广播
- 张量 A 形状为
(2, 1, 3)
- 张量 B 形状为
(1, 4, 1)
对齐后的形状均为:(2, 4, 3)
A 的第二个维度可以扩展,B 的第一个和第三个维度可以扩展。结果张量的形状为 (2, 4, 3)
。
A = torch.tensor([[[1, 2, 3]], [[4, 5, 6]]]) # 形状 (2, 1, 3)
B = torch.tensor([[[1], [2], [3], [4]]]) # 形状 (1, 4, 1)
result = A * B
print(result)
# 输出:
# tensor([[[[ 1, 2, 3],
# [ 2, 4, 6],
# [ 3, 6, 9],
# [ 4, 8, 12]],
# [[ 4, 5, 6],
# [ 8, 10, 12],
# [12, 15, 18],
# [16, 20, 24]]]])
示例 5: 形状为(2, 3)和(3, 2)无法广播
- 张量 A 形状为
(2, 3)
- 张量 B 形状为
(3, 2)
A和B无法对齐。
在这种情况下,A 和 B 的对应维度都不匹配,既没有维度为1,也没有相等的维度,因此无法广播。
try:
A = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 形状 (2, 3)
B = torch.tensor([[1, 2], [3, 4], [5, 6]]) # 形状 (3, 2)
result = A * B
except RuntimeError as e:
print(f"无法广播: {e}")
# 输出:
# RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1
总结
广播机制在以下情况下适用:
- 从右向左对齐张量的形状。
- 如果某一维度不匹配且不为1或空,则无法广播。
- 如果某一维度为1或空,则可以扩展该维度以匹配另一个张量在该维度的大小。
广播机制使得不同形状的张量能够进行灵活的逐元素运算,从而简化代码和提高计算效率。