今天讲二维和三维数组如何索引,重点在二维,三维是从二维拓展而来;另外,末尾我们还介绍了一个工具,专门用来索引,即torch.index_select(),或者a.index_select()。
二维
二维数组很容易索引,例如:
c=torch.rand(2,3)
print(c)
#取行
print(c[[0,1],:])
#取列
print(c[:,[1,2]])
#取行列,所以就是某个元素了。
print(c[[0],[0]])#取(0,0)
print(c[[0,0,0],[0,1,2]])#取(0,0),(0,1),(0,2)
需要注意的是,上述写成下述则不是我们想要的:
index=torch.tensor([[0,0,0],[0,1,2]])
print(c[index])
索引的时候我们最好是使用list,上述可以改成:
index=torch.tensor([[0,0,0],[0,1,2]])
index=index.tolist()
print(c[index])
三维
三维其实可以从二维中递推而得到,基本原理还是一样,之前是有两个位置,所以是一个逗号,现在变成三个位置,所以是2个逗号。
这里只做简单的演示即可:
index_select
a=torch.rand(3,4)
print(a)
indices = torch.tensor([0, 2])
print(a.index_select(0,indices))#取0行和2行。
print(a.index_select(1,indices))#取0列和2列。
gather
最近又看到一个函数,属实懵逼了,上面这些索引好像就够了,但是这个也能索引。
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
比如说我们想要索引出5,7,9。那么用gather怎么实现呢?我们可以指定维度为行,dim=1,然后再指定[2,1,0],这就表示[0,2],[1,1],[2,0]了;那么如果我们先指定维度为列,dim=0,那么似乎就无法索引出[5,7,9]这个顺序了,其会索引出[9,7,5]的顺序。我们来看一下:
index = torch.tensor([[2], [1], [0]])
torch.gather(tensor_0,1,index)
tensor([[5],
[7],
[9]])
可以看到,相当于他会帮你补齐另外一个坐标,[2],那么由于维度是1,从而是[0,2],而不是[2,0]。另外,一个硬性要求是:index必须和tensor_0维度相同,即都是二维的。这个要求我觉得挺奇葩的。
index = torch.tensor([[2,1,0]])
torch.gather(tensor_0,0,index)
tensor([[9, 7, 5]])