torch.gather()
定义:从原tensor中获取指定dim和指定index的数据
看到这个核心定义,我们很容易想到gather()
的基本想法其实就类似从完整数据中按索引取值般简单,比如下面从列表中按索引取值
lst = [1, 2, 3, 4, 5]
value = lst[2] # value = 3
value = lst[2:4] # value = [3, 4]
上面的取值例子是取单个值或具有逻辑顺序序列的例子,而对于深度学习常用的批量tensor数据来说,我们的需求可能是选取其中多个且乱序的值,此时gather()
就是一个很好的tool,它可以帮助我们从批量tensor中取出指定乱序索引下的数据,因此其用途如下
用途:方便从批量tensor中获取指定索引下的数据,该索引是 高度自定义化的,可乱序的
根据官方文档的显示是根据给出的index的索引坐标来确定要寻找的坐标,然后根据dim来确定将哪个位置坐标换位index中的数字,剩余维度的位置坐标保持不变。 并且input的shape应该是和index的shape保持一致。
到这里再回去看官方文档是不是就能看懂了!!!
【PyTorch】Torch.gather()用法详细图文解释
实战
import torch
import torch.nn.functional as F
a = torch.arange(0,16).view(4,4)
print(a)
index = torch.tensor([[0,1,2,3]])
print(index)
print(a.gather(0, index))
print(a.gather(1, index))
index = torch.tensor([[3, 2, 1, 0]])
print(index)
tensor_1 = a.gather(0, index)
print(tensor_1)
tesnor_2 = a.gather(1,index)
print(tesnor_2)
输出:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[0, 1, 2, 3]])
tensor([[ 0, 5, 10, 15]])
tensor([[0, 1, 2, 3]])
tensor([[3, 2, 1, 0]])
tensor([[12, 9, 6, 3]])
tensor([[3, 2, 1, 0]])
三维实例:
import torch
random_seed = 200
torch.manual_seed(random_seed)
input = torch.randint(0, 100, (2, 3, 4))
print("input:")
print(input)
index = torch.randint(0, 2, (2, 1, 2))
print("index:")
print(index)
output = input.gather(0, index)
print("output:")
print(output)
# 控制台输出
input:
tensor([[[62, 29, 76, 60],
[82, 27, 88, 11],
[57, 50, 71, 9]],
[[33, 71, 66, 34],
[20, 81, 3, 39],
[15, 33, 19, 89]]])
index:
tensor([[[0, 1]],
[[1, 0]]])
output:
tensor([[[62, 71]],
[[33, 29]]])