对于PyTorch中的torch.gather()的理解

label:
		tensor([[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
		        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
		        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
		        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]])
		        
idx:
		tensor([[0, 5, 7],
		        [3, 5, 6],
		        [2, 5, 8],
		        [1, 6, 9]])

torch.gather(label,dim = 1,index = idx.long())
#输出:
out = 
		tensor([[100, 105, 107],
		        [103, 105, 106],
		        [102, 105, 108],
		        [101, 106, 109]])
#label.shape = 4*10 , 为4行10列的tensor,参数dim=1指的是在label的第一个维度取值
#也就是在label的列取值, idx指的就是dim=1的索引,也就是label列索引
#如果dim=0,则idx指的就是dim=0的索引,也就是label行索引
#out的size跟idx的size一致
#idx(0,0)的值为0,由于dim=1,所以应该取出label第0行,第0列的值,所以out(0,0) = 100
#idx(0,1)的值为5,由于dim=1,所以应该取出label第0行,第5列的值,所以out(0,1) = 105
#idx(0,2)的值为7,由于dim=1,所以应该取出label第0行,第7列的值,所以out(0,2) = 107
## 接下来,进入到idx的第二行,所以,label也应该进入到第二行,out也是如此。
#idx(1,0)的值为3,由于dim=1,所以应该取出label第1行,第3列的值,所以out(1,0) = 103
#idx(1,1)的值为5,由于dim=1,所以应该取出label第1行,第1列的值,所以out(1,1) = 105
#idx(1,2)的值为6,由于dim=1,所以应该取出label第1行,第6列的值,所以out(1,0) = 106
#看出规律了吗,因为dim=1,idx的值就是要取的label的列索引,idx的行号,就是要取的label的行索引
##想必看到这里,你已经明白了gather的作用
##我们再来测试一下dim=0

torch.gather(label,dim = 0,index = idx.long()):
#报错如下:
#RuntimeError                              Traceback (most #recent call last)
#~/Desktop/Study/手写数字.py in 
#----> 346 torch.gather(label,dim = 0,index = idx.long())
#
#RuntimeError: index 5 is out of bounds for dimension 0 with size 4

##哦豁,竟然报错了,因为我们的idx中,有大于4的值,为什么是大于4呢,因为label的第一个维度为4,理解了吧
#我们重新换个idx1来试下:
idx1  = torch.tensor([[0,1,0,3,2,1,0,1]]) 

label:
		tensor([[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
		        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
		        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
		        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]])

torch.gather(label,dim = 0,index = idx1.long())
#输出:
out =
    tensor([[100, 101, 102, 103, 104, 105, 106, 107]])
#还是老规矩,out的shape,一定是和idx1一样的
#idx1中没有大于4的值,因为label的shape为4*10,所以当dim=0时,idx中的值表示的就是label的行号了,所以必须小于4.
#接下来我们来分析为什么out是这样的结果:
#idx1.shape = (1,8)
#idx1(0,0)的值为0,dim=0,所以应该取出label的第0行,第0列的值,所以out(0,0) = 100
#idx1(0,1)的值为1,dim=0,所以应该取出label的第1行,第1列的值,所以out(0,1) = 101
#idx1(0,2)的值为0,dim=0,所以应该取出label的第0行,第2列的值,所以out(0,2) = 102
#idx1(0,3)的值为3,dim=0,所以应该取出label的第3行,第3列的值,所以out(0,3) = 103
#idx1(0,4)的值为2,dim=0,所以应该取出label的第2行,第4列的值,所以out(0,4) = 104
#看出来了吗,由于dim=0,现在idx1的值,就是要取的label的行号,idx1的列索引,就是要取的label的列号
#如果到这里还不理解的话,可以再更换下idx试试。

看到这里,如果让你恍然大悟的话,点个赞吧😁

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ToTensor

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值