label:
tensor([[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
[100, 101, 102, 103, 104, 105, 106, 107, 108, 109]])
idx:
tensor([[0, 5, 7],
[3, 5, 6],
[2, 5, 8],
[1, 6, 9]])
torch.gather(label,dim = 1,index = idx.long())
#输出:
out =
tensor([[100, 105, 107],
[103, 105, 106],
[102, 105, 108],
[101, 106, 109]])
#label.shape = 4*10 , 为4行10列的tensor,参数dim=1指的是在label的第一个维度取值
#也就是在label的列取值, idx指的就是dim=1的索引,也就是label列索引
#如果dim=0,则idx指的就是dim=0的索引,也就是label行索引
#out的size跟idx的size一致
#idx(0,0)的值为0,由于dim=1,所以应该取出label第0行,第0列的值,所以out(0,0) = 100
#idx(0,1)的值为5,由于dim=1,所以应该取出label第0行,第5列的值,所以out(0,1) = 105
#idx(0,2)的值为7,由于dim=1,所以应该取出label第0行,第7列的值,所以out(0,2) = 107
## 接下来,进入到idx的第二行,所以,label也应该进入到第二行,out也是如此。
#idx(1,0)的值为3,由于dim=1,所以应该取出label第1行,第3列的值,所以out(1,0) = 103
#idx(1,1)的值为5,由于dim=1,所以应该取出label第1行,第1列的值,所以out(1,1) = 105
#idx(1,2)的值为6,由于dim=1,所以应该取出label第1行,第6列的值,所以out(1,0) = 106
#看出规律了吗,因为dim=1,idx的值就是要取的label的列索引,idx的行号,就是要取的label的行索引
##想必看到这里,你已经明白了gather的作用
##我们再来测试一下dim=0
torch.gather(label,dim = 0,index = idx.long()):
#报错如下:
#RuntimeError Traceback (most #recent call last)
#~/Desktop/Study/手写数字.py in
#----> 346 torch.gather(label,dim = 0,index = idx.long())
#
#RuntimeError: index 5 is out of bounds for dimension 0 with size 4
##哦豁,竟然报错了,因为我们的idx中,有大于4的值,为什么是大于4呢,因为label的第一个维度为4,理解了吧
#我们重新换个idx1来试下:
idx1 = torch.tensor([[0,1,0,3,2,1,0,1]])
label:
tensor([[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
[100, 101, 102, 103, 104, 105, 106, 107, 108, 109]])
torch.gather(label,dim = 0,index = idx1.long())
#输出:
out =
tensor([[100, 101, 102, 103, 104, 105, 106, 107]])
#还是老规矩,out的shape,一定是和idx1一样的
#idx1中没有大于4的值,因为label的shape为4*10,所以当dim=0时,idx中的值表示的就是label的行号了,所以必须小于4.
#接下来我们来分析为什么out是这样的结果:
#idx1.shape = (1,8)
#idx1(0,0)的值为0,dim=0,所以应该取出label的第0行,第0列的值,所以out(0,0) = 100
#idx1(0,1)的值为1,dim=0,所以应该取出label的第1行,第1列的值,所以out(0,1) = 101
#idx1(0,2)的值为0,dim=0,所以应该取出label的第0行,第2列的值,所以out(0,2) = 102
#idx1(0,3)的值为3,dim=0,所以应该取出label的第3行,第3列的值,所以out(0,3) = 103
#idx1(0,4)的值为2,dim=0,所以应该取出label的第2行,第4列的值,所以out(0,4) = 104
#看出来了吗,由于dim=0,现在idx1的值,就是要取的label的行号,idx1的列索引,就是要取的label的列号
#如果到这里还不理解的话,可以再更换下idx试试。
看到这里,如果让你恍然大悟的话,点个赞吧😁