torch.index_select函数,顾名思义就是根据index索引在input输入张量中选择某些特定的元素,下面介绍该函数的参数。
torch.index_select(input, dim, index, out=None):
- input:输入Tensor,在该Tensor上根据index和dim进行切片;
- dim:切片维度,在input这个Tensor的哪个维度上进行dinex索引;
- index:Tensor.LongTensor类型的1-D Tensor,在dim维度上需要索引的下标(自己尝试过非1-D的index,结果报错 Index is supposed to be
1-dimensional,如有不对欢迎指正);- out:用来承载函数的返回值(也可以直接用变量x=torch.index_select(input, dim, index)进行承载,不需要out参数)
import torch
x = torch.rand(3,5)
index = torch.LongTensor([2,0])
# 如果想在x的第一个维度上选择x[2]和x[0]
y = torch.index_select(x, dim=0, index=index)
# 如果想在x的第二个维度上选择,即x[...,2]和x[...,0]
y = torch.index_select(x, dim=1, index=index)
# 另外,也可以用以下方法
y = x.new()
torch.index_select(x, dim=0, index=index, out=y)