在PyTorch中,广播(Broadcasting)是一种用于在不同形状的张量之间执行逐元素操作的机制。在进行逐元素操作时,如果两个张量的形状不完全匹配,PyTorch会自动使用广播机制来进行形状的扩展,使得两个张量的形状相容,从而进行逐元素操作。
广播机制遵循以下规则:
1. 当两个张量的维度个数不同,将维度较少的张量通过在前面插入长度为1的维度来扩展,直到两个张量具有相同的维度个数。
2. 当两个张量在某个维度上的长度不匹配时,如果其中一个张量在该维度上的长度为1,那么可以通过复制该张量的值来扩展该维度,使得两个张量在该维度上的长度相同。
3. 如果以上两个步骤无法使得两个张量的形状匹配,那么会抛出形状不兼容的错误。
广播机制可以应用于一系列的逐元素操作,例如加法、减法、乘法、除法等。通过广播机制,我们可以方便地对形状不同的张量进行逐元素操作,避免了手动扩展张量的操作。
例子1:
import torch
# 创建两个形状不同的张量
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 形状为(2, 3)
b = torch.tensor([10, 20, 30]) # 形状为(3,)
# 使用广播机制进行逐元素相加
c = a + b # 广播机制会自动将b扩展为(2, 3),使得a和b的形状相同
print(c)
输出结果:
tensor([[11, 22, 33],
[14, 25, 36]])
例子2(注意看shape):
'''
·每个张量至少一个维度
·满足右对齐
·torch.rand(2,1,1)+torch.rand(3)
'''
import torch
a = torch.rand(2, 3) # 2 * 3
b = torch.rand(3) # 1 * 3
c = a + b # 2 * 3
print(c)
print(c.shape)
a = torch.rand(2, 1, 1, 3) # 2 * 1 * 1 * 3
b = torch.rand(4, 2, 3) # 1 * 4 * 2 * 3
c = a + b # 2 * 4 * 2 * 3
print(c)
print(c.shape)
输出结果:
tensor([[0.8484, 0.7692, 1.4322],
[0.8699, 0.9497, 1.3924]])
torch.Size([2, 3])
tensor([[[[1.5291, 1.0000, 0.8863],
[1.1274, 1.4687, 0.5827]],
[[0.9342, 1.5905, 0.5801],
[1.2628, 0.8225, 0.4521]],
[[1.2618, 1.1720, 1.2192],
[1.1753, 1.2166, 0.4413]],
[[0.8598, 0.5976, 0.5721],
[1.5772, 1.5361, 0.6881]]],
[[[1.8213, 1.2127, 0.9392],
[1.4195, 1.6814, 0.6356]],
[[1.2264, 1.8032, 0.6330],
[1.5550, 1.0352, 0.5050]],
[[1.5539, 1.3847, 1.2721],
[1.4674, 1.4293, 0.4942]],
[[1.1520, 0.8103, 0.6250],
[1.8694, 1.7488, 0.7410]]]])
torch.Size([2, 4, 2, 3])
Process finished with exit code 0