torch.gather是一个PyTorch函数,用于按照给定的索引从输入张量中收集指定的元素。它接受三个参数:
input:输入张量,其形状为 (B, N, C),其中 B 表示批次大小,N 表示每个批次中的元素数,C 表示每个元素的特征数。
dim:表示需要收集的维度。例如,如果 dim=1,则表示沿着第二个维度收集元素。
index:一个张量,包含了从输入张量 input 中收集元素所需的索引。它的形状为 (B, M),其中 M 表示每个批次中要收集的元素数量。
函数会返回一个形状为 (B, M, C) 的张量,其中第 i 个批次的第 j 个元素是输入张量 input 中第 i 个批次的第 index[i][j] 个元素。
这个函数通常用于实现集中式注意力机制,其中 input 表示输入特征,index 表示每个元素需要聚焦的位置。
以下是几个使用torch.gather函数的示例:
示例1:获取每个批次的最大值
import torch
# 创建一个大小为 (2, 4, 3) 的输入张量
input_tensor = torch.tensor([
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
[[13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]]
])
# 沿着第二个维度获取每个批次的最大值的索引
max_indices = torch.argmax(input_tensor, dim=1)
# 使用 torch.gather 收集每个批次的最大值
max_values = torch.gather(input_tensor, dim=1, index=max_indices.unsqueeze(-1).repeat(1, 1, input_tensor.shape[-1]))
print(max_values)
# 输出:
# tensor([[[10, 11, 12]],
# [[22, 23, 24]]])
这个示例中,我们首先使用torch.argmax函数获取每个批次的最大值的索引,然后使用torch.gather函数收集这些最大值。我们在索引张量的最后一个维度上添加了一个额外的维度,以便torch.gather可以正确地收集最大值。
示例2:从每个批次的张量中收集指定位置的元素
import torch
# 创建一个大小为 (2, 4, 3) 的输入张量
input_tensor = torch.tensor([
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
[[13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]]
])
# 创建一个大小为 (2, 2) 的索引张量
indices = torch.tensor([[1, 3], [0, 2]])
# 从每个批次的张量中收集指定位置的元素
gathered_values = torch.gather(input_tensor, dim=1, index=indices.unsqueeze(-1).repeat(1, 1, input_tensor.shape[-1]))
print(gathered_values)
# 输出:
# tensor([[[ 4, 5, 6],
# [10, 11, 12]],
#
# [[13, 14, 15],
# [19, 20, 21]]])
在这个示例中,我们创建了一个索引张量,指定了要从每个批次的张量中收集的元素的位置。我们然后使用torch.gather函数收集这些元素,并将结果保存在一个新的张量中。
一个复杂一些的例子
indices = torch.unsqueeze(tokens[:,
(prev_context_length + 1):(context_length + 1)],2)
output_log_probs[:,prev_context_length:context_length] = \
torch.gather(log_probs, 2, indices).squeeze(2)
这段代码主要实现了一个从一个Tensor中按照给定索引收集元素的操作。
首先,代码通过切片操作tokens[:, (prev_context_length + 1):(context_length + 1)]获取了一个张量,这个张量是输入log_probs张量中需要收集元素的索引。这个张量的形状应该是(B, L, 1),其中B表示batch大小,L表示序列长度,因为torch.gather需要输入的索引张量的形状是(B, L, 1)。
接下来,使用torch.unsqueeze函数在这个张量的最后一个维度上增加了一个维度,将其形状变成(B, L, 1, 1)。这是为了方便torch.gather函数的使用,因为torch.gather需要在索引张量的最后一个维度上收集元素。
然后,代码使用torch.gather函数,将log_probs张量中需要收集的元素按照索引张量indices收集起来,并将结果存储在output_log_probs张量的相应位置中。这里,dim=2表示在log_probs张量的第3个维度(即特征数)上收集元素。最后,使用squeeze函数去除结果张量中的冗余维度,得到最终的输出结果。