import torch as t
a = t.arange(0, 16).view(4, 4)
print(a)
index = t.LongTensor([[0,2,2,3],[1,1,0,3]])
b=a.gather(0, index)
print('-----------------------')
print(b)
b=a.gather(1, index)
print('........................')
print(b)
刚看到gather这个方法,这里面比较懵的是 dim=0,dim=1,其实只要这么理解就好,记住:
dim=0代表行
dim=1代表列
接下来:
a =tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
index = t.LongTensor([[0, 2, 2, 3],
[1, 1, 0, 3]])
二维的数组取值:一般取值是先找第几行,再找第几列,好的,
b=a.gather(0, index) # 这里的dim=0,那么意思就是 index 里面的值代表行,
比如 [0, 2, 2, 3] 就代表依次取第0行,第2行,第2行,第3行,那么列呢?它们取第几个,这里的列就根据其在第几列就取第几列,
[0, 2, 2, 3]中
0是第0列,则取第0列,
2是在第一列,则取第一列,
后面一个2是在第二列,则取第二列,
3是在第三列,则取第三列
所以[0,2,2,3]取值是:
a[0][0],
a[2][1],
a[2][2],
a[3][3]
dim=0,意思就是里面的值代表第几行,列就看对应的值在第几列,就是取第几列
接下来看dim=1
dim=1就是代表里面的值是取第几列,但是取第几行呢?这个行就默认其在第几行就取第几行,官网要求 input.size() == index.size(),这里要求的维度相同,所以
index = t.LongTensor([[0,2,2,3],[1,1,0,3]])
d=a.gather(1, index)
这里[[0,2,2,3],[1,1,0,3]]中[0,2,2,3]在第0行,所以都是在第0行取,第几列就按照里面的值,0代表第0列,2代表第2列,3代表第3列;[1,1,0,3]代表都是在第1行取值,
1代表第1列,0代表第0列,3代表第3列
所以:
dim=0,就代表index中的值是行,至于列,就看其在第几列就取第几列
dim=1,就代表index中的值是列,至于行,就看其在第几行就取第几行
上下两行对照看就比较容易理解了,要么有行,列就按照其所在的列,要么有列,行就按照整体在哪一行;以上解释是由结果推到的,看了很多人的博客和官网,都没有很好理解。