torch.gather() 函数理解

本文主要参考 https://zhuanlan.zhihu.com/p/352877584 ,在对 torch.gather() 理解之后,总结了一套比较好用的计算方法,下面直接来看例子。

>>> import torch
>>> tensor_0 = torch.arange(3, 12).view(3, 3)

Test 1

>>> index = torch.tensor([[2, 1, 0]])
>>> tensor_1 = tensor_0.gather(0, index)

# tensor_0
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
# tensor_1 
tensor([[9, 7, 5]])

首先根据 index 的 shape = (1, 3),我们将这个水平长条对应到 tensor_0 的 [[3, 4, 5]] 上(尽量往左上角放)。gather() 函数中第一个维度参数为 0,表明沿着竖直方向根据 index 替换数值,例如 [[3, 4, 5]] 中的 3,由于其对应位置 index 为 [[2, 1, 0]] 中的 2,因此将 3 替换成该的 9。以此类推,4 换成 7,5 还是 5。

Test 2

>>> index = torch.tensor([[2, 1, 0]])
>>> tensor_1 = tensor_0.gather(1, index)

# tensor_0
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
# tensor_1 
tensor([[5, 4, 3]])

首先根据 index 的 shape = (1, 3),我们将这个水平长条对应到 tensor_0 的 [[3, 4, 5]] 上。gather() 函数中第一个维度参数为 1,表明沿着水平方向根据 index 替换数值,例如 [[3, 4, 5]] 中的 3,由于其对应位置 index 为 [[2, 1, 0]] 中的 2,因此将 3 替换成该的 5。以此类推,4 还是 4,5 换成 3。

Test 3

>>> index = torch.tensor([[2, 1, 0]]).t()  # 表示一列 [[2], [1], [0]]
>>> tensor_1 = tensor_0.gather(1, index)

# tensor_0
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
# tensor_1
tensor([[5],
        [7],
        [9]])

首先根据 index 的 shape = (3, 1),我们将这个竖直长条对应到 tensor_0 的 [[3], [6], [9]] 上。gather() 函数中第一个维度参数为 1,表明沿着水平方向根据 index 替换数值,例如 [[3], [6], [9]] 中的 3,由于其对应位置 index 为 [[2], [1], [0]] 中的 2,因此将 3 替换成该的 5。以此类推,6 换成 7,9 还是 9。

Test 4

>>> index = torch.tensor([[0, 2], 
                          [1, 2]])
>>> tensor_1 = tensor_0.gather(1, index)

# tensor_0
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
# tensor_1
tensor([[3, 5],
        [7, 8]])

首先根据 index 的 shape = (2, 2),我们将这个水平长条对应到 tensor_0 的 [[3, 4], [6, 7]] 上。gather() 函数中第一个维度参数为 1,表明沿着水平方向根据 index 替换数值,例如 [[3, 4], [6, 7]] 中的 4,由于其对应位置 index 为 [[0, 2], [1, 2]] 中的 2,因此将 4 替换成该的 5。以此类推,同一行的 3 还是 3,第二行的 6 换成7,7 换成 8。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值