torch.gather函数的简单理解与使用

功能:根据索引来对高维tensor进行选择
要求:

  • input tensor 与 index 的 dim一致
  • index.shape < input.shape
torch.gather(input, dim, index) → Tensor

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
import torch

# [[1,2,3],
#  [4,5,6],
#  [7,8,9]]
input = torch.range(1, 9).view(3, 3)

示例1

#====================== dim=0 索引======================#
# 表示提供采样索引维度的是dim=0
dim = 0		
# index的维度为 [1, 3], 说明输出维度也为 [1, 3], index每个元素的索引为 [[(0, 0), (0, 1), (0, 2)]]
index = torch.tensor([[2, 1, 0]])
# 用index 来替换index的索引列表中的dim=0 得到: [[(2, 0), (1, 1), (0, 2)]]
output = torch.gather(input, dim, index)
# 将input索引为 (2, 0), (1, 1), (0, 3) 取出来就是 [[7, 5, 3]] 	

在这里插入图片描述


示例2

#======================== dim=1 索引=====================#
# 表示提供的采样索引维度dim=1
dim = 1
# index的维度为 (1, 3), 也就是index每个元素的索引为    [[(0, 0), (0, 1), (0, 2)]], 
index = torch.tensor([[2, 1, 0]])
# 用index的取值来替代 index的索引列表中的dim=1的元素得:[[(0,2) (0,1) (0,0)]]
# 将input索引为 (0,2) (0,1) (0,0) 取出来就是[[3, 2, 1]]
optput = torch.gather(input, dim, index)

在这里插入图片描述


示例3

#=============================================#
dim = 1	# 表示采样索引为1
index = torch.tensor([[2],
		 			  [1],
		 			  [0]])
# index的索引为[(0, 0),
			   (1, 0),
			   (2, 0)]
# 使用index的其余维索引来补全后得到:
'''
 [(0, 2),
  (1, 1),
  (2, 0)]
'''
# 对input索引
output = torch.gather(input, dim, index)
'''
[[3],
 [5],
 [7]]
'''

在这里插入图片描述


示例4

#====================== 多维index =====================#
dim = 1
index = torch.tensor([[0, 2], 
                      [1, 2]])
# index 的索引为 [[(0,0), (0,1)], 
# 				 [(1,0), (1,1)]]
# 用除了1维以外的索引将index补全得:
# [[(0,0), (0,2)], 
#  [(1,1), (1,2)]]

# 对input索引
output = torch.gather(input, dim, index)		  
#[[0, 3]
# [5, 6]]

在这里插入图片描述

更高维度的gather索引也是如此,先生成index每个元素的索引,再用index的值来替代dim维度的索引值,最后按照索引值到input中索引得到output

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CV科研随想录

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值