PyTorch 中广播机制(Broadcasting)笔记

在 PyTorch 中存在广播(Broadcasting),广播是一种机制,用于自动扩展较小的张量以匹配较大张量的形状,从而使得它们能够进行元素级操作(如加法、减法、乘法等)。广播并不改变张量的实际数据,而是通过虚拟扩展来简化操作。

广播机制的规则

  1. 如果两个张量的维度数量不同,则将较小的那个张量的形状前面补 1,直到两个张量的维度数量相同。

  2. 如果两个张量在某个维度上的大小不一致,但其中一个张量在该维度上的大小是 1,则可以在该维度上进行广播。

  3. 如果两个张量在任何维度上的大小既不相等也不为 1,则无法进行广播。

  4. 广播后的张量形状是每个维度上大小的最大值。

import torch

# 示例 1: 形状不同的张量相加
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([1, 2, 3])
# b 会被广播成 [[1, 2, 3], [1, 2, 3]]
result = a + b
print(result)
# 输出:
# tensor([[ 2,  4,  6],
#         [ 5,  7,  9]])

# 示例 2: 形状不同的张量相乘
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = torch.tensor([1, 2])
# b 会被广播成 [[1, 2], [1, 2], [1, 2]]
result = a * b
print(result)
# 输出:
# tensor([[ 1,  4],
#         [ 3,  8],
#         [ 5, 12]])

# 示例 3: 形状不同的张量相加
a = torch.tensor([1, 2, 3])
b = torch.tensor([[1], [2], [3]])
# a 会被广播成 [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
# b 会被广播成 [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
result = a + b
print(result)
# 输出:
# tensor([[2, 3, 4],
#         [3, 4, 5],
#         [4, 5, 6]])

广播机制在张量乘法中的应用

在进行张量乘法时,广播机制也可以简化操作,尤其是在批次维度不同时:

import torch

# 张量 A 的形状是 (2, 3, 4)
A = torch.randn(2, 3, 4)

# 张量 B 的形状是 (4, 5)
B = torch.randn(4, 5)

# 使用广播机制进行张量乘法
# B 会被广播成 (2, 4, 5)
result = torch.matmul(A, B)
print(result.shape)
# 输出:
# torch.Size([2, 3, 5])

判断两个张量是否可以进行广播操作

主要遵循以下规则:

  1. 如果两个张量的维度数量不同,则将较小的那个张量的形状前面补 1,直到两个张量的维度数量相同。
  2. 从最后一个维度开始,逐个维度向前检查:
  • 如果两个张量在某个维度上的大小相同,或者其中一个张量在该维度上的大小是 1,则可以在该维度上进行广播。
  • 如果两个张量在任何维度上的大小既不相等也不为 1,则无法进行广播。

具体步骤
假设有两个张量 A 和 B,其形状分别为 shapeA 和 shapeB。

  1. 对齐维度:将较小的形状前面补 1,使得两个形状的长度相同。
  2. 逐维度检查:从最后一个维度开始,逐个维度向前检查:
  • 如果两个维度大小相同,或其中一个维度大小为 1,则该维度可以进行广播。
  • 如果两个维度大小既不相同且都不为 1,则无法进行广播。

以下是一个 Python 函数,用于判断两个张量是否可以进行广播操作:

def can_broadcast(shapeA, shapeB):
    # 对齐维度
    lenA, lenB = len(shapeA), len(shapeB)
    if lenA < lenB:
        shapeA = (1,) * (lenB - lenA) + shapeA
    elif lenB < lenA:
        shapeB = (1,) * (lenA - lenB) + shapeB

    # 逐维度检查
    for dimA, dimB in zip(shapeA, shapeB):
        if dimA != dimB and dimA != 1 and dimB != 1:
            return False
    return True

# 示例
shapeA = (2, 3, 4)
shapeB = (4, 5)
print(can_broadcast(shapeA, shapeB))  # 输出: False
'''
对齐维度:(2, 3, 4) 和 (1, 4, 5)
第三个维度:4 和 5,不相等且都不为 1,无法广播。
第二个维度:3 和 4,不相等且都不为 1,无法广播。
第一个维度:2 和 1,可以广播。
'''

shapeA = (2, 3, 4)
shapeB = (1, 4, 5)
print(can_broadcast(shapeA, shapeB))  # 输出: False
'''
对齐维度:(2, 3, 4) 和 (1, 4, 5)
第三个维度:4 和 5,无法广播。
第二个维度:3 和 4,无法广播。
第一个维度:2 和 1,可以广播。
'''

shapeA = (3, 4)
shapeB = (2, 1, 4)
print(can_broadcast(shapeA, shapeB))  # 输出: True
'''
对齐维度:(1, 3, 4) 和 (2, 1, 4)
'''
shapeA = (5,)
shapeB = (1, 5)
print(can_broadcast(shapeA, shapeB)) # 输出: True
'''
对齐维度:将较小的形状前面补 1,使得两个形状的长度相同。得到 (1, 5) 和 (1, 5)
'''

广播机制结合张量乘法例子

示例 1: 形状为 (2, 3, 4) 和 (4,) 的张量

import numpy as np

# 张量 A 的形状为 (2, 3, 4)
A = np.random.rand(2, 3, 4)

# 张量 B 的形状为 (4,)
B = np.random.rand(4)

# 广播机制会将 B 扩展为 (1, 1, 4),然后再扩展为 (2, 3, 4)
result = A * B

print("A.shape:", A.shape) # A.shape: (2, 3, 4)
print("B.shape:", B.shape) # B.shape: (4,)
print("result.shape:", result.shape) # 逐元素乘法操作成功进行,结果的形状为 (2, 3, 4)

示例 2: 形状为 (2, 3, 4) 和 (3, 4) 的张量

import numpy as np

# 张量 A 的形状为 (2, 3, 4)
A = np.random.rand(2, 3, 4)

# 张量 B 的形状为 (3, 4)
B = np.random.rand(3, 4)

# 广播机制会将 B 扩展为 (1, 3, 4),然后再扩展为 (2, 3, 4)
result = A * B

print("A.shape:", A.shape)
print("B.shape:", B.shape)
print("result.shape:", result.shape)

示例 3: 形状为 (2, 1, 3) 和 (3, 4) 的张量

import numpy as np

# 张量 A 的形状为 (2, 1, 3)
A = np.random.rand(2, 1, 3)

# 张量 B 的形状为 (3, 4)
B = np.random.rand(3, 4)

# 广播机制会将 A 扩展为 (2, 1, 3),B 扩展为 (1, 3, 4)
# 矩阵乘法会在最后两个维度上进行
result = np.matmul(A, B)

print("A.shape:", A.shape)
print("B.shape:", B.shape)
print("result.shape:", result.shape)

广播机制允许我们在不同形状的张量之间进行逐元素操作或矩阵操作,而无需显式地扩展张量的形状。这大大简化了张量操作的复杂性,并提高了代码的可读性和效率。

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值