第34个方法
torch.index_select(input, dim, index, *, out=None) → Tensor
此方法的作用是,根据dim取出input中的index对应的元素,并且返回一个tensor,首先是参数介绍
input(Tensor)
:输入的tensor。dim(int)
:指定我们要进行索引选择的维度。index(LongTensor)
:索引。out
:输出tensor。
使用方法:
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-0.4664, 0.2647, -0.1228, -1.1068],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
[-0.4664, -0.1228],
[-1.1734, 0.7230]])
-
可以看到,当dim为0时,直接抽取了第0维上的元素,抽取的是索引为0和2的元素,符合预期,当dim为1时,抽取第1维上的元素,索引为0和2。
-
返回的结果Tensor和input具有相同的维度。在指定的dim维度上的长度为索引(index)的长度,在上面公式中为2,其实这个很好理解,因为索引里有几个元素,,我们就在input中指定维度上取几个元素。而在其它维度上和原tensor长度相等。
注意:返回的tensor和原tensor并不使用相同的内存。并且如果out的形状和预期不相等,pytorch会默默的将结果修改为正确的,并且在必要时重新分配内存。