torch.gather()总结

torch.gather沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合.

1. 二维情况下

(1)case1: dim=0

import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print("tensor_0:", tensor_0)
print("tensor_1", tensor_1)

输出

tensor_0: tensor([[ 3,  4,  5],
       			  [ 6,  7,  8],
       			  [ 9, 10, 11]])
       			  
tensor_1 tensor([[9, 7, 5]])
# note:  dim=0从列里面选,【9】是第一列中第2个数,【7】是第二列第1个数,【5】是第三列第0个数

(2)case2: dim=1

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2],[1],[0]])
tensor_1 = tensor_0.gather(1, index)
print("tensor_0:", tensor_0)
print("tensor_1", tensor_1)

输出

tensor_0: tensor([[ 3,  4,  5],
		       	  [ 6,  7,  8],
		          [ 9, 10, 11]])
tensor_1 tensor([[5],
		      	 [7],
		    	 [9]])
# note: dim=1, 从行里取【5】是第一行第二个数,【7】是第二行第1个数,【9】是第三行第0个数

case3: 一行中取多个数

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2,1,0],[1,1,1],[0,1,0]])
tensor_1 = tensor_0.gather(1, index)
print("tensor_0:", tensor_0)
print("tensor_1", tensor_1)

输出

tensor_0: tensor([[ 3,  4,  5],
		          [ 6,  7,  8],
		          [ 9, 10, 11]])
tensor_1 tensor([[ 5,  4,  3],
		       	[ 7,  7,  7],
		       	[ 9, 10,  9]])

Case4: 一列中取多个数

    tensor_0 = torch.arange(3, 12).view(3, 3)
    index = torch.tensor([[2,1,0],[1,1,1],[0,1,0]])
    tensor_1 = tensor_0.gather(0, index)
    print("tensor_0:", tensor_0)
    print("tensor_1", tensor_1)

输出

tensor_0: tensor([[ 3,  4,  5],
		          [ 6,  7,  8],
		          [ 9, 10, 11]])
tensor_1 tensor([[9, 7, 5],
		       	[6, 7, 8],
		        [3, 7, 5]])

2. 三维情况

case1: dim=1

    a = torch.randint(0, 30, (2, 3, 5))
    index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                        [[1,2,2,2,2],
                         [0,0,0,0,0],
                         [2,2,2,2,2]]])
    b = torch.gather(a, 1, index)
    print("a:", a)
    print("b:", b)

输出:

a: tensor([[[13,  1, 25, 18, 28],
         	[24, 19,  5, 25, 11],
         	[13, 13, 20,  9, 22]],

           [[22, 18, 12,  9,  1],
         	[ 6, 11, 23, 11, 29],
        	[15,  9,  8, 29,  6]]])
        	
b: tensor([[[13, 19, 20, 18, 22],
         	[13,  1, 25, 18, 28],
         	[24, 19,  5, 25, 11]],

           [[ 6,  9,  8, 29,  6],
         	[22, 18, 12,  9,  1],
         	[15,  9,  8, 29,  6]]])
         	
 # note: dim=1从列中取

case2: dim=2

    a = torch.randint(0, 30, (2, 3, 5))
    index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                        [[1,2,2,2,2],
                         [0,0,0,0,0],
                         [2,2,2,2,2]]])
    b = torch.gather(a, 2, index)
    print("a:", a)
    print("b:", b)

输出

a: tensor([[[ 0, 19,  3, 20, 29],
         [ 4,  2,  1,  8, 13],
         [16, 15, 13, 29, 10]],

        [[25, 18, 16,  0,  6],
         [ 3,  4, 13, 23, 19],
         [ 7, 21, 28, 17, 11]]])
         
b: tensor([[[ 0, 19,  3,  0,  3],
         [ 4,  4,  4,  4,  4],
         [15, 15, 15, 15, 15]],

        [[18, 16, 16, 16, 16],
         [ 3,  3,  3,  3,  3],
         [28, 28, 28, 28, 28]]])
#  dim=2 从行里取数

case3: dim=0

    a = torch.randint(0, 30, (2, 3, 5))
    index = torch.LongTensor([[[0,1,1,0,1],
                          [0,1,1,1,1],
                          [1,1,1,1,1]],
                        [[1,0,0,0,0],
                         [0,0,0,0,0],
                         [1,1,0,0,0]]])
    b = torch.gather(a, 0, index)
    print("a:", a)
    print("b:", b)

输出

a: tensor([[[ 9,  3, 10, 19,  4],
         	[26, 19, 20,  9, 28],
         	[ 5, 21, 29, 26, 24]],

           [[10,  2, 11, 29, 26],
         	[20, 25, 17, 11, 16],
         	[ 4, 17, 27, 17, 29]]])
         
b: tensor([[[ 9,  2, 11, 19, 26],
         	[26, 25, 17, 11, 16],
         	[ 4, 17, 27, 17, 29]],

           [[10,  3, 10, 19,  4],
         	[26, 19, 20,  9, 28],
         	[ 4, 17, 29, 26, 24]]])
# dim = 0时,索引代表在第几页取数,取数的位置为索引i所在的坐标,如上:index[0][0][0]=0,表示取a中第0页(0,0)的数9,index[0][0][1]=1表示取第1页的(0,1)坐标的数3

在这里插入图片描述

总结

  • index的维数必须与输入维数相同,输入为2维矩阵,index也必须为2维矩阵
  • 在二维矩阵中dim=0 表示列,dim=1表示行,三维矩阵中,dim=0表示页,dim=1表示列,dim=2表示行
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值