TORCH.GATHER函数的简单理解

一,版权声明

        版权声明:本文为weixin_44291388原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

        本文链接:torch.gather函数的简单理解_乾巽的博客-CSDN博客_torch。gatherhttps://blog.csdn.net/weixin_44291388/article/details/104139447

二,官方文档

        pytorch官网关于torch.gather的文档:torch.gather — PyTorch 1.11.0 documentationhttps://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather

三,理解

接下来是原文作者理解部分:

import torch
# 先看看out 和 index 都是二维数组的情况
# out[i][j] = tensor[index[i][j]][j]  # dim=0
# out[i][j] = tensor[[i][index[i][j]]  # dim=1
 
t = torch.Tensor([[1,2],[3,4]])
# t = 1 2
#     3 4
index = torch.LongTensor([[0,0],[1,0]])
# index = 0 0
#         1 0
print(torch.gather(t, 1, index))  #此时dim = 1
# 输出  1 1
#       4 3
# 输出的结果的size =  index.size()
 
# 讲解过程
# index[0][0] = 0
# index[0][1] = 0
# index[1][0] = 1
# index[1][1] = 0
# dim = 1
# out[0][0] = tensor[[0]index[0][0]] == tensor[0][0] == 1
# out[0][1] = tensor[[0]index[0][1]] == tensor[0][0] == 1
# out[1][0] = tensor[[1]index[1][0]] == tensor[1][1] == 4
# out[1][1] = tensor[[1]index[1][1]] == tensor[1][0] == 3
 
 
#例二 dim = 1
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 1])
print(y.view(-1,1))  # 2行1列
 
print(y_hat.gather(1, y.view(-1, 1)))
# 0.1
# 0.2
 
# y[0][0] = y_hat[[0][y[0][0]] == y_hat[0][0] == 0.1
# y[1][0] = y_hat[[1][y[1][0]] == y_hat[1][1] == 0.1
                
 
#例三  dim = 0
t = torch.Tensor([[1,2],[3,4]])
# t = 1 2
#     3 4
index = torch.LongTensor([[0,0],[1,0]])
# index = 0 0
#         1 0
print(torch.gather(t, 0, index))  #此时dim = 1
# 输出  1 2
#       3 2
# index[0][0] = 0
# index[0][1] = 0
# index[1][0] = 1
# index[1][1] = 0
# dim = 0
# out[0][0] = tensor[[index[0][0]][0]] == tensor[0][0] == 1
# out[0][1] = tensor[[index[0][1]][1]] == tensor[0][1] == 2
# out[1][0] = tensor[[index[1][0]][0]] == tensor[1][1] == 3
# out[1][1] = tensor[[index[1][1]][1]] == tensor[0][1] == 2

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值