torch.gather

torch.gather

  1. 参数:torch.gather(input, dim, index, out=None)
  2. 注意:
    · output.shape == index.shape
    · index.dtype == torch.LongTensor
    · index不能为负数,-1也会报错
  3. 分析:对于3维张量
    · dim=0, output(i,j,k) = input(index(i,j,k), j, k)
    · dim=1, output(i,j,k) = input(i, index(i,j,k), k)
    · dim=2, output(i,j,k) = input(i, j, index(i,j,k))
  4. 例子1: 二维
    In [1]: x = torch.arange(4).reshape(2,2)
    In [2]: x
    Out[2]: tensor([[0, 1],
        			[2, 3]])     			
    In [3]: gather_index = torch.LongTensor([[0, 1],[0, 1]])
    In [4]: torch.gather(x, dim=1, index=gather_index)
    Out[4]: tensor([[0, 1],
        	 	    [2, 3]])
    
  5. 例子2: 三维
    In [1]: x = torch.arange(12).reshape(2,3,2)
    In [2]: x
    Out[2]: tensor([[[0, 1],
         			 [2, 3],
         		     [4, 5]],
           		    [[6, 7],
            		 [8, 9],
            		 [10, 11]]])	
    In [3]: gather_index = gather_index.unsqueeze(-1).expand(x.shape[0],2,x.shape[-1])
    In [4]: gather_index
    Out[4]:tensor([[[0, 0],
         			[1, 1]],
        		   [[0, 0],
         			[1, 1]]])
    In [5]: torch.gather(x, dim=1, index=gather_index)
    Out[5]: tensor([[[0, 1],
         			 [2, 3]],
                    [[6, 7],
                     [8, 9]]])
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值