最近开始接触pytorch,发现torch.index_select不是很好理解,就查了一下文档:
torch.index_select(input, dim, index, out=None) → Tensor
-
input :输入一个张量
-
dim :索引所依赖的维度
-
index :索引的index
-
out :返回的张量,默认为None
torch.index_select()在input的张量在dim的维度上按照index索引并返回一个新的张量。这样子不太好理解,我们拿一个例子来看看:
导入torch包,创建一个3×4的二维张量x
传入x,并在第0维上(这里是二维,所以直接是在行上)按照[0,2]索引(也就是选择index为0和2的行,也就是第1行和第三行)
同理,这里只是把维度改成了1,这里就是索引x的第一列和第三列。
这样子写也能实现一样的效果。
另外,官方文档上面有这样一段描述:
意思就是说,返回的张量和输入的张量具有相同的维数。然后返回的张量中,第dim维(就是代码中的dim)跟的大小跟代码中的index一样长,其它维的大小跟输入张量保持一致,还是上面的例子如下图:
输入张量是3×4,这是在第1维(此处就是列)进行索引,indices长度为2。输出张量是第0维是3,与输入张量一致,第0维是2,与indices一致。