1.返回值和参数解释
按照官网的解释
- return(Tensor) 该方法返回一个Tensor,输出结果是按照输入的Tensor进行索引得到
-
input(Tensor)–输入张量。(原Tensor,得到的结果就是从该Tensor中产生)
-
dim(int)–我们索引的维度。(常用值:0/1 表示:按行/列索引)
-
index(LongTensor)–包含要索引的索引的一维张量。(选择原Tensor某些行或者某些列)
-
out(Tensor,可选)–输出张量。
2.使用方法以及实例
import torch
x = torch.randn(4,3)
print('x = ',x)
# x 是原Tensor
# 0 按行索引
# torch.Tensor([1,2,3]) 截取原Tensor的2,3,4 行
y = torch.index_select(x, 0, torch.LongTensor([1,2,3]))
print('y = ',y)
# 此方式得到的结果与y 相同
print(x.index_select(0,torch.LongTensor([1,2,3])))
# 1 按列索引
# torch.Tensor([0,2]) 截取原Tensor的 1,3 列
z = torch.index_select(x, 1, torch.LongTensor([0,2]))
print('z = ',z)
print(x.index_select(1,torch.LongTensor([0,2])))
结果:
x = tensor([[ 1.4381, 1.4937, 0.1507],
[ 1.0306, -1.8240, 0.6450],
[ 1.5915, -0.0304, -0.5963],
[-0.8520, -0.3590, -1.3559]])
y = tensor([[ 1.0306, -1.8240, 0.6450],
[ 1.5915, -0.0304, -0.5963],
[-0.8520, -0.3590, -1.3559]])
tensor([[ 1.0306, -1.8240, 0.6450],
[ 1.5915, -0.0304, -0.5963],
[-0.8520, -0.3590, -1.3559]])
z = tensor([[ 1.4381, 0.1507],
[ 1.0306, 0.6450],
[ 1.5915, -0.5963],
[-0.8520, -1.3559]])
tensor([[ 1.4381, 0.1507],
[ 1.0306, 0.6450],
[ 1.5915, -0.5963],
[-0.8520, -1.3559]])