pytorch的gather函数的一些粗略的理解

pytorch的gather函数的一些理解

先给出官方文档的解释,我觉得官方的文档写的已经很清楚了,四个参数分别是input,dim,index,out,输出的tensor是以index为大小的tensor。
在这里插入图片描述
其中,这就是最关键的定义

out[i][j][k] = tensor[index[i][j][k]][j][k]  # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k]  # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]]  # dim=3

主要解释一下dim,dim=0的时候,把index的元素放入进行索引,有一点需要注意的是,参数index的tensor格式是除了第1维也就是行那一维之外,其他维的格式需与input保持一致!下面给个例子

import torch 

a = torch.arange(0, 16).view(4,4)

index = torch.LongTensor([[0,1,2,3]])

b = a.gather(0, index)
print(a)
print(index)
print(b)

#形象的理解就是在每一列的第index[]上进行索引
for j in range(4):
    print(a[index[0][j]][j].item())
--------------------------------------------------------------------
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
tensor([[0, 1, 2, 3]])
tensor([[ 0,  5, 10, 15]])
0
5
10
15

dim = 1的时候,把index的元素放入进行索引,有一点需要注意的是,参数index的tensor格式是除了第2维也就是列那一维之外,其他维的格式需与input保持一致!下面给个例子

import torch 

a = torch.arange(0, 16).view(4,4)

index = torch.LongTensor([[0],[1],[2],[3]])

b = a.gather(1, index)
print(a)
print(index)
print(b)

#形象的理解就是在每一行的第index[]列上进行索引
for j in range(4):
    print(a[j][index[j][0]].item())
--------------------------------------------------------------------
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
tensor([[0],
        [1],
        [2],
        [3]])
tensor([[ 0],
        [ 5],
        [10],
        [15]])
0
5
10
15

本人对矩阵的一些概念还有一些模糊不清,以上就是我的一些理解,希望有大佬可以一起交流一下,pytorch 的张量一开始很难处理清楚,还需慢慢来。

已标记关键词 清除标记
相关推荐
©️2020 CSDN 皮肤主题: 深蓝海洋 设计师:CSDN官方博客 返回首页