Pytorch常用的函数(九)torch.gather()用法

Pytorch常用的函数(九)torch.gather()用法

torch.gather() 就是在指定维度上收集value。

torch.gather() 的必填也是最常用的参数有三个,下面引用官方解释:

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

一句话概括 gather 操作就是:根据 index ,在 inputdim 维度上收集 value

1、举例直观理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

# 2、我们有index_tensor如下
>>> index_tensor = torch.tensor(
       [[[0, 0, 0, 0],
         [2, 2, 2, 2]],
         
        [[0, 0, 0, 0],
         [2, 2, 2, 2]]]
)	

# 3、我们通过torch.gather()函数获取out_tensor
>>> out_tensor = torch.gather(input_tensor, dim=1, index=index_tensor)
tensor([[[ 0,  1,  2,  3],
         [ 8,  9, 10, 11]],
         
        [[12, 13, 14, 15],
         [20, 21, 22, 23]]])

我们以out_tensor中[0,1,0]=8为例,解释下如何利用dim和index,从input_tensor中获得8。

在这里插入图片描述

根据上图,我们很直观的了解根据 index ,在 inputdim 维度上收集 value的过程。

  • 假设 inputindex 均为三维数组,那么输出 tensor 每个位置的索引是列表 [i, j, k] ,正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可;
  • 但是由于 dim 的存在以及 input.shape 可能不等于 index.shape ,所以直接取值可能就会报错 ;
  • 所以我们是将索引列表的相应位置替换为 dim ,再去 input 取值。在上面示例中,由于dim=1,那么我们就替换索引列表第1个值,即[i,dim,k],因此由原来的[0,1,0]替换为[0,2,0]后,再去input_tensor中取值。
  • pytorch官方文档的写法如下,同一个意思。
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

2、反推法再理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

# 2、假设我们要得到out_tensor如下
>>> out_tensor
tensor([[[ 0,  1,  2,  3],
         [ 8,  9, 10, 11]],
         
        [[12, 13, 14, 15],
         [20, 21, 22, 23]]])# 3、如何知道dim 和 index_tensor呢? 
# 首先,我们要记住:out_tensor的shape = index_tensor的shape

# 从 output_tensor 的第一个位置开始:
# 此时[i, j, k]一样,看不出来 dim 应该是多少
output_tensor[0, 0, :] = input_tensor[0, 0, :] = 0
# 同理可知,此时index都为0
output_tensor[0, 0, 1] = input_tensor[0, 0, 1] = 1
output_tensor[0, 0, 2] = input_tensor[0, 0, 2] = 2
output_tensor[0, 0, 3] = input_tensor[0, 0, 3] = 3

# 我们从下一行的第一个位置开始:
# 这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim 应该是 1,而 index 应为 2
output_tensor[0, 1, 0] = input_tensor[0, 2, 0] = 8
# 同理可知,此时index都为2
output_tensor[0, 1, 1] = input_tensor[0, 2, 1] = 9
output_tensor[0, 1, 2] = input_tensor[0, 2, 2] = 10
output_tensor[0, 1, 3] = input_tensor[0, 2, 3] = 11

# 根据上面推导我们易知dim=1,index_tensor为:
>>> index_tensor = torch.tensor(
       [[[0, 0, 0, 0],
         [2, 2, 2, 2]],
         
        [[0, 0, 0, 0],
         [2, 2, 2, 2]]]
)	

3、实际案例

在大神何凯明MAE模型(Masked Autoencoders Are Scalable Vision Learners)源码中,多次使用了torch.gather() 函数。

在MAE中根据预设的掩码比例(paper 中提倡的是 75%),使用服从均匀分布的随机采样策略采样一部分 tokens 送给 Encoder,另一部分mask 掉。采样25%作为unmasked tokens过程中,使用了torch.gather() 函数。

# models_mae.py

import torch

def random_masking(x, mask_ratio=0.75):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.
    x: [N, L, D], sequence
    """
    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))  # 计算unmasked的片数
    # 利用0-1均匀分布进行采样,避免潜在的【中心归纳偏好】
    noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

    # sort noise for each sample【核心代码】
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    # 利用torch.gather()从源tensor中获取25%的unmasked tokens
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore

if __name__ == '__main__':
    x = torch.arange(64).reshape(1, 16, 4)
    random_masking(x)
# x模拟一张图片经过patch_embedding后的序列
# x相当于input_tensor
# 16是patch数量,实际上一般为(img_size/patch_size)^2 = (224 / 16)^2 = 14*14=196
# 4是一个patch中像素个数,这里只是模拟,实际上一般为(in_chans * patch_size * patch_size = 3*16*16 = 768)
>>> x = torch.arange(64).reshape(1, 16, 4) 
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15],
         [16, 17, 18, 19], # 4
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31],
         [32, 33, 34, 35],
         [36, 37, 38, 39],
         [40, 41, 42, 43], # 10
         [44, 45, 46, 47],
         [48, 49, 50, 51], # 12
         [52, 53, 54, 55], # 13
         [56, 57, 58, 59],
         [60, 61, 62, 63]]])
# dim=1, index相当于index_tensor
>>> index
tensor([[[10, 10, 10, 10],
         [12, 12, 12, 12],
         [ 4,  4,  4,  4],
         [13, 13, 13, 13]]])


# x_masked(从源tensor即x中,随机获取25%(4个patch)的unmasked tokens)     
>>> x_masked相当于out_tensor
tensor([[[40, 41, 42, 43],
         [48, 49, 50, 51],
         [16, 17, 18, 19],
         [52, 53, 54, 55]]])
  • 5
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: torch.gather函数PyTorch中的一个函数,用于在给定维度上按索引从输入张量中提取元素并构建新的张量。 torch.gather函数的语法为:torch.gather(input, dim, index, out=None)。 参数说明: - input:输入张量,即需要从中提取元素的张量。 - dim:要在哪个维度上进行提取操作。 - index:一个包含需要提取元素的索引的张量。 - out:一个可选的输出张量。 在torch.gather函数中,我们会按照dim指定的维度,在input张量上进行提取操作。提取操作是根据index张量中给定的索引值来进行的。最终会构建一个新的张量,其中包含了根据索引从input张量中提取出来的元素。 例如,如果input是一个2维张量,shape为(3,4),而index是一个1维张量,shape为(3,),则dim的取值范围为[0, 1]。如果dim=0,那么提取操作将沿着第一个维度进行,在每一列上按照index张量中对应的值进行元素的提取。如果dim=1,那么提取操作将沿着第二个维度进行,在每一行上按照index张量中对应的值进行元素的提取。 使用torch.gather函数可以灵活地根据给定的索引从输入张量中提取出所需的元素,这对于实现一些特定需求的操作非常有用。例如,可以在处理图像分类任务时,根据预测的类别标签,从softmax输出概率中提取出对应类别的概率,进而用于计算损失函数或者评估模型性能等。 ### 回答2: torch.gather函数是一个PyTorch中的操作函数,用于在指定维度上根据索引获取原始张量中的元素。这个函数的使用方式为: output = torch.gather(input, dim, index, out=None, sparse_grad=False) 其中,input是原始的张量,dim是指定的维度,index是需要提取的元素的索引。函数会根据dim指定的维度,在input张量中提取index中指定的元素,并返回一个新的张量output。 例如,假设input是一个3x4的二维张量,index是一个2x3的二维张量,dim的取值为1,那么torch.gather函数会在input的第1个维度上根据index中的元素索引,提取相应的元素。最终得到的output是一个2x3的张量。 torch.gather函数在很多机器学习任务中非常有用。例如,在序列标注任务中,我们可以使用torch.gather函数根据标签索引来选择对应的预测结果。在图像分类任务中,我们可以根据类别索引使用torch.gather函数进行结果的选择。此外,在自然语言处理任务中,torch.gather函数也可以用来根据单词的索引来选择对应的词向量。 需要注意的是,所提取的元素的维度必须与index的维度一致,否则会引发异常。此外,dim的取值必须在0到input的维度之间,否则也会引发异常。如果不指定out参数,函数会返回一个新的张量作为输出,如果指定了out参数,则会把提取的结果保存到指定的张量中。最后,如果sparse_grad为True,则会返回一个稀疏梯度,否则返回一个密集梯度。 总之,torch.gather函数提供了一种方便和高效地根据索引提取元素的方式,广泛应用于各种机器学习任务中。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值