pytorch的gather函数的一些粗略的理解

pytorch的gather函数的一些理解

先给出官方文档的解释,我觉得官方的文档写的已经很清楚了,四个参数分别是input,dim,index,out,输出的tensor是以index为大小的tensor。
在这里插入图片描述
其中,这就是最关键的定义

out[i][j][k] = tensor[index[i][j][k]][j][k]  # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k]  # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]]  # dim=3

主要解释一下dim,dim=0的时候,把index的元素放入进行索引,有一点需要注意的是,参数index的tensor格式是除了第1维也就是行那一维之外,其他维的格式需与input保持一致!下面给个例子

import torch 

a = torch.arange(0, 16).view(4,4)

index = torch.LongTensor([[0,1,2,3]])

b = a.gather(0, index)
print(a)
print(index)
print(b)

#形象的理解就是在每一列的第index[]上进行索引
for j in range(4):
    print(a[index[0][j]][j].item())
--------------------------------------------------------------------
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
tensor([[0, 1, 2, 3]])
tensor([[ 0,  5, 10, 15]])
0
5
10
15

dim = 1的时候,把index的元素放入进行索引,有一点需要注意的是,参数index的tensor格式是除了第2维也就是列那一维之外,其他维的格式需与input保持一致!下面给个例子

import torch 

a = torch.arange(0, 16).view(4,4)

index = torch.LongTensor([[0],[1],[2],[3]])

b = a.gather(1, index)
print(a)
print(index)
print(b)

#形象的理解就是在每一行的第index[]列上进行索引
for j in range(4):
    print(a[j][index[j][0]].item())
--------------------------------------------------------------------
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
tensor([[0],
        [1],
        [2],
        [3]])
tensor([[ 0],
        [ 5],
        [10],
        [15]])
0
5
10
15

本人对矩阵的一些概念还有一些模糊不清,以上就是我的一些理解,希望有大佬可以一起交流一下,pytorch 的张量一开始很难处理清楚,还需慢慢来。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值