pytorch的广播机制

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


PyTorch 的广播机制可以使不同形状的张量能够进行逐元素操作,如加减乘除,前提是这些张量的形状是兼容的。广播机制遵循一套特定的规则来扩展较小的张量以匹配较大张量的形状。

广播机制规则

  1. 从右向左对齐形状:比较两个张量的形状时,从最右边的维度开始对齐。
  2. 维度匹配:对应位置的两个维度要么相等,要么其中一个为 1或空,否则运算会失败。
  3. 维度扩展:当一个维度为 1 时,可以沿该维度复制,使其与另一个张量的相应维度匹配。

广播机制示例

假设有两个张量 AB,以下是一些示例来解释广播规则:

示例 1: 形状为(4, 1)和()可以广播
  • 张量 A 形状为 (4, 1)
  • 张量 B 形状为 ()

对齐后的形状均为:(4, 1)

在这种情况下,B 是一个标量,它会被扩充出两个维度。广播过程详述为:

  • 按照规则从右开始对齐维度数字,A的的第一个维度数字为1,B第一个维度数字为空,此时满足扩充条件“其中一方为空”,B扩充为形状为(1),数据为[1]的张量;
  • 往左继续对齐维度数字,A的的第二个维度数字为4,B第二个维度数字为空,此时满足扩充条件“其中一方为空”,B扩充为形状为(4,1),数据为
[
	[1],
	[1],
	[1],
	[1],
]

的张量;

  • 往左继续对齐,发现AB都没有再高的维度。最终B被广播成 (4, 1)的形状。
import torch

A = torch.tensor([[1], [2], [3], [4]])  # 形状 (4, 1)
B = torch.tensor(1)         # 形状 (4,)
result = A + B
print(result)
# 输出:
# tensor([[2],
#         [3],
#         [4],
#         [5]])
示例 2: 形状为(4, 1)和(4,)可以广播
  • 张量 A 形状为 (4, 1)
  • 张量 B 形状为 (4,)

对齐后的形状均为:(4, 1)

在这种情况下,B 的第二个维度可以扩展为1,因此可以广播成 (4, 1)。结果张量的形状为 (4, 1)

import torch

A = torch.tensor([[1], [2], [3], [4]])  # 形状 (4, 1)
B = torch.tensor([5, 6, 7, 8])         # 形状 (4,)
result = A * B
print(result)
# 输出:
# tensor([[ 5,  6,  7,  8],
#         [10, 12, 14, 16],
#         [15, 18, 21, 24],
#         [20, 24, 28, 32]])
示例 3: 形状为(1, 3, 1)和(3, 1)可以广播
  • 张量 A 形状为 (1, 3, 1)
  • 张量 B 形状为 (3, 1)

对齐后的形状均为:(1, 3, 1)

在这种情况下,B 的形状被视为 (1, 3, 1),结果张量的形状为 (1, 3, 1)

A = torch.tensor([[[1], [2], [3]]])  # 形状 (1, 3, 1)
B = torch.tensor([[4], [5], [6]])    # 形状 (3, 1)
result = A * B
print(result)
# 输出:
# tensor([[[ 4],
#          [ 5],
#          [ 6]],
#         [[ 8],
#          [10],
#          [12]],
#         [[12],
#          [15],
#          [18]]])
示例 4: 形状为(2, 1, 3)和(1, 4, 1)可以广播
  • 张量 A 形状为 (2, 1, 3)
  • 张量 B 形状为 (1, 4, 1)

对齐后的形状均为:(2, 4, 3)

A 的第二个维度可以扩展,B 的第一个和第三个维度可以扩展。结果张量的形状为 (2, 4, 3)

A = torch.tensor([[[1, 2, 3]], [[4, 5, 6]]])  # 形状 (2, 1, 3)
B = torch.tensor([[[1], [2], [3], [4]]])      # 形状 (1, 4, 1)
result = A * B
print(result)
# 输出:
# tensor([[[[ 1,  2,  3],
#           [ 2,  4,  6],
#           [ 3,  6,  9],
#           [ 4,  8, 12]],
#          [[ 4,  5,  6],
#           [ 8, 10, 12],
#           [12, 15, 18],
#           [16, 20, 24]]]])
示例 5: 形状为(2, 3)和(3, 2)无法广播
  • 张量 A 形状为 (2, 3)
  • 张量 B 形状为 (3, 2)

A和B无法对齐。

在这种情况下,A 和 B 的对应维度都不匹配,既没有维度为1,也没有相等的维度,因此无法广播。

try:
    A = torch.tensor([[1, 2, 3], [4, 5, 6]])  # 形状 (2, 3)
    B = torch.tensor([[1, 2], [3, 4], [5, 6]])  # 形状 (3, 2)
    result = A * B
except RuntimeError as e:
    print(f"无法广播: {e}")
# 输出:
# RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1

总结

广播机制在以下情况下适用:

  1. 从右向左对齐张量的形状。
  2. 如果某一维度不匹配且不为1或空,则无法广播。
  3. 如果某一维度为1或空,则可以扩展该维度以匹配另一个张量在该维度的大小。

广播机制使得不同形状的张量能够进行灵活的逐元素运算,从而简化代码和提高计算效率。

  • 22
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值