【Pytorch】torch.gather

定义:从原tensor中获取指定dim和指定index的数据,生成新的tensor

输入

import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
tensor_0
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

Example-1

dim: 0

index_1 = torch.tensor([[2]])
tensor_1 = tensor_0.gather(0, index_1)
print("tensor_1", tensor_1)

index_2 = torch.tensor([[2, 1]])
tensor_2 = tensor_0.gather(0, index_2)
print("tensor_2", tensor_2)

index_3 = torch.tensor([[2, 1, 0]])
tensor_3 = tensor_0.gather(0, index_3)
print("tensor_3", tensor_3)
tensor_1 tensor([[9]])   # 对应下标[(2,0)]
tensor_2 tensor([[9, 7]])   # 对应下标[(2,0), (1,1)]
tensor_3 tensor([[9, 7, 5]])  # 对应下标[(2,0), (1,1), (0,2)]

理解
d i m = 0 , i n d e x = t o r c h . t e n s o r ( [ [ 2 , 1 , 0 ] ] ) dim=0,index=torch.tensor([[2, 1, 0]]) dim=0,index=torch.tensor([[2,1,0]]):表示将取出下标 [ ( 0 , 0 ) , ( 0 , 1 ) , ( 0 , 2 ) ] [(0,0), (0,1), (0,2)] [(0,0),(0,1),(0,2)] d i m = 1 dim=1 dim=1维度不变, d i m = 0 dim=0 dim=0维度根据 i n d e x index index修改为 [ ( 2 , 0 ) , ( 1 , 1 ) , ( 0 , 2 ) ] [(2,0), (1,1), (0,2)] [(2,0),(1,1),(0,2)]

Example-2

dim: 0

index_4 = torch.tensor([[2]])
tensor_4 = tensor_0.gather(1, index_4)
print("tensor_4", tensor_4)

index_5 = torch.tensor([[2, 1]])
tensor_5 = tensor_0.gather(1, index_5)
print("tensor_5", tensor_5)

index_6 = torch.tensor([[2, 1, 0]])
tensor_6 = tensor_0.gather(1, index_6)
print("tensor_6", tensor_6)
tensor_4 tensor([[5]])  # 对应下标[(0,2)]
tensor_5 tensor([[5, 4]])  # 对应下标[(0,2), (0,1)]
tensor_6 tensor([[5, 4, 3]])  # 对应下标[(0,2), (0,1), (0, 0)]

理解
d i m = 1 , i n d e x = t o r c h . t e n s o r ( [ [ 2 , 1 , 0 ] ] ) dim=1,index=torch.tensor([[2, 1, 0]]) dim=1,index=torch.tensor([[2,1,0]]):表示将取出下标 [ ( 0 , 0 ) , ( 0 , 1 ) , ( 0 , 2 ) ] [(0,0), (0,1), (0,2)] [(0,0),(0,1),(0,2)] d i m = 0 dim=0 dim=0维度不变, d i m = 1 dim=1 dim=1维度根据 i n d e x index index修改为 [ ( 0 , 2 ) , ( 0 , 1 ) , ( 0 , 0 ) ] [(0,2), (0,1), (0,0)] [(0,2),(0,1),(0,0)]

Example-3

dim: 0

index_7 = torch.tensor([[0, 2], [1, 2], [0, 2]])
tensor_7 = tensor_0.gather(1, index_7)
tensor_7
tensor([[ 3,  5],
        [ 7,  8],
        [ 9, 11]])

理解
d i m = 1 , i n d e x = t o r c h . t e n s o r ( [ [ 0 , 2 ] , [ 1 , 2 ] , [ 0 , 2 ] ] ) dim=1,index=torch.tensor([[0, 2], [1, 2], [0, 2]]) dim=1,index=torch.tensor([[0,2],[1,2],[0,2]]):表示将取出下标 [ [ ( 0 , 0 ) , ( 0 , 1 ) ] , [ ( 1 , 0 ) , ( 1 , 1 ) ] , [ ( 2 , 0 ) , ( 2 , 1 ) ] ] [[(0,0), (0,1)], [(1, 0) , (1, 1)], [(2, 0), (2, 1)]] [[(0,0),(0,1)],[(1,0),(1,1)],[(2,0),(2,1)]] d i m = 0 dim=0 dim=0维度不变, d i m = 1 dim=1 dim=1维度根据 i n d e x index index修改为 [ [ ( 0 , 0 ) , ( 0 , 2 ) ] , [ ( 1 , 1 ) , ( 1 , 2 ) ] , [ ( 2 , 0 ) , ( 2 , 2 ) ] ] [[(0,0), (0,2)], [(1, 1) , (1, 2)], [(2, 0), (2, 2)]] [[(0,0),(0,2)],[(1,1),(1,2)],[(2,0),(2,2)]]

  • tensor
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
  • 取值结果
修改前下标index修改后下标对应的数值
[(0,0), (0,1)][0, 2][(0, 0), (0, 2)][ 3, 5]
[(1, 0) , (1, 1)][1, 2][(1, 1) , (1, 2)][ 7, 8]
[(2, 0), (2, 1)][0, 2][(2, 0), (2, 2)][ 9, 11]

参考链接

图解PyTorch中的torch.gather函数

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值