介绍
官网介绍:https://pytorch.org/docs/stable/generated/torch.Tensor.gather.html
torch.Tensor.gather
是PyTorch中的一个函数,它根据索引从输入张量中收集值。
示例代码1
以下是一个使用torch.Tensor.gather
的示例:
import torch
# 创建一个输入张量
input = torch.tensor([[1, 2], [3, 4]])
# 创建一个索引张量
index = torch.tensor([[0, 0], [1, 0]])
# 使用gather函数
output = input.gather(1, index)
print(output)
# tensor([[1, 1],
# [4, 3]])
解析:
- index[0,0]=0对应的值为input[0][0]=1
- index[0,1]=0对应的值为input[0][0]=1
- index[1,0]=1对应的值为input[1][1]=4
- index[1,1]=0对应的值为input[1][0]=3
在上述代码中,input.gather(1, index)
会沿着维度1(列)收集值。索引张量index
中的每个值指定了在相应位置收集哪个元素。因此,output
的值为[[1, 1], [4, 3]]
。
示例代码2
以下是一个使用torch.Tensor.gather
的3维tensor示例:
import torch
# 创建一个输入张量
input = torch.arange(0,8).view(2, 2, 2)
# tensor([[[0, 1],
# [2, 3]],
# [[4, 5],
# [6, 7]]])
# 创建一个索引张量
index = torch.tensor([[[0,0]],[[1,0]]])
# 使用gather函数
output = input.gather(1, index)
print(output)
# tensor([[[0, 1]],
# [[6, 5]]])
解析:
- index[0,0,0]=0对应的值为input[0][0][0]=0
- index[0,0,1]=0对应的值为input[0][0][1]=1
- index[1,0,0]=1对应的值为input[1][1][0]=6
- index[1,0,1]=0对应的值为input[1][0][1]=5
从上面可以看出index对应的值,被用作了input的第1维索引,其他维度不变。
错误示例
# 错误示例
index = torch.tensor([[[0,0]],[[10,0]]])
input.gather(1, index) # RuntimeError: index 10 is out of bounds for dimension 1 with size 2
input在第一维度大小为2,所以不能用10来当做索引。