1 机制
torch.broadcast_tensors
是 PyTorch 中的一个将tensor扩充的函数
在a, b = torch.broadcast_tensors(x, y)中,是将x与y的形状“黏合”起来,从而组成两个形状相同的tensor:a、b,
在不复制数据的情况下就能进行运算,整个过程可以做到避免无用的复制,达到更高效的运算。
广播机制实际上是在运算过程中,去处理两个形状不同向量的一种手段。
pytorch中的广播机制和numpy中的广播机制一样, 因为都是数组的广播机制。
2 一顿乱试
示例:
import torch
w21 = torch.randint(low=1, high=21, size=(2, 1))
w12 = torch.randint(low=1, high=21, size=(1, 2))
w12a = torch.randint(low=1, high=21, size=(1, 2))
w13 = torch.randint(low=1, high=21, size=(1, 3))
w14 = torch.randint(low=1, high=21, size=(1, 4))
w41 = torch.randint(low=1, high=21, size=(4, 1))
w42 = torch.randint(low=1, high=21, size=(4, 2))
w22 = torch.randint(low=1, high=21, size=(2, 2))
x211 = torch.randint(low=1, high=21, size=(2, 1, 1))
x121 = torch.randint(low=1, high=21, size=(1, 2, 1))
x112 = torch.randint(low=1, high=21, size=(1, 1, 2))
y = torch.randint(low=1, high=21, size=(3, 2, 1))
# 广播两个张量
a, b = torch.broadcast_tensors(w12, w21) # √
a, b = torch.broadcast_tensors(w21, w12) # √
a, b = torch.broadcast_tensors(w12, w41) # √
a, b = torch.broadcast_tensors(w13, w41) # √
a, b = torch.broadcast_tensors(w22, w41) #×
a, b = torch.broadcast_tensors(w22, w42) #×
a, b
3 两个张量进行广播机制的条件
必要条件:两个张量都至少有一个维度
按从右往左顺序看两个张量的每一个维度,每个对应着的两个维度都需要能够匹配上。
匹配条件:
a.这两个维度的大小相等
b. 某个维度 一个张量有,一个张量没有
c.某个维度 一个张量有,一个张量也有但大小是1
如下例
x321 = torch.randint(low=1, high=21, size= (3, 2, 1))
x2123 = torch.randint(low=1, high=21, size=(2, 1, 2, 3))
a, b = torch.broadcast_tensors(x2123, x321) # √
a, b
a.shape, b.shape
运行结果
a.shape, b.shape
(torch.Size([2, 3, 2, 3]), torch.Size([2, 3, 2, 3]))