Pytorch函数——torch.gather详解

在学习强化学习时,顺便复习复习pytorch的基本内容,遇到了 torch.gather()函数,参考图解PyTorch中的torch.gather函数 - 知乎 (zhihu.com)进行解释。

pytorch官网对函数给出的解释:

image.png

即input是一个矩阵,根据dim的值,将index的值替换到不同的维度的索引,当dim为0时,index替代i的值,成为第0维度的索引。

输入和输出的矩阵形式相同。

例子:首先我们生成3×3的矩阵,明确行索引的概念,第0行指的是[3,4,5],第0列指的是[[3] [6] [9]]

import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

index为行向量且dim=0时

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
tensor([[9, 7, 5]])

当dim=0时,替换第0维度。由于input为二维列表,因此第0维度指的是选择第几行的维度,即行索引所在的维度,替换了i的索引,为input[index[i][j]] [j]

那么我们会输出tensor([[ input[2][j] input[1][j] input[0][j] ]]),那么j如何获得呢?从index of index中拿到,index每一个元素的索引为(0,0) (0,1) (0,2),取j,则为0,1,2,那么输出则为tensor([[ input[2][0] input[1][1] input[0][2] ]]),即

tensor([[9, 7, 5]])

输入行向量index,且dim=1

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[5, 4, 3]])

维度为1,则替换列索引的值,那么输出为tensor([[ input[i][2] input[i][1] input[i][0] ]]),index每一个元素的索引为(0,0) (0,1) (0,2),i均为1,那么tensor([[ input[0][2] input[0][1] input[0][0] ]])

输入为行向量,dim=0

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
tensor([[5],
        [7],
        [9]])

维度为0,则替换行索引,且输出与输入的格式相同,为

tensor([input[2][j],
        input[1][j],
        input[0][j]])

index每一个元素的索引为(0,0) (1,0) (2,0),j对应的值为0,0,0,则

tensor([input[2][0],
        input[1][0],
        input[0][0]])

输入为行向量,dim=1

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[5],
        [7],
        [9]])

维度为1,则替换列索引,且输出与输入的格式相同,为

tensor([input[i][2],
        input[i][1],
        input[i][0]])

index每一个元素的索引为(0,0) (1,0) (2,0),i对应的值为0,1,2,则

tensor([input[0][2],
        input[1][1],
        input[2][0]])

输入为二维矩阵,且dim=1

index = torch.tensor([[0, 2],
                      [1, 2]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[3, 5],
        [7, 8]])

维度为1,则替换列索引,且输出与输入的格式相同,为

tensor([[input[i][0], input[i][2]],
        [input[i][1], input[i][2]]])

替换为行索引后,可得:

tensor([[input[0][0], input[0][2]],
        [input[1][1], input[1][2]]])

在强化学习中的应用

在PyTorch官网DQN页面的代码中,i是state,j是a

# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
# columns of actions taken. These are the actions which would've been taken
# for each batch state according to policy_net
state_action_values = policy_net(state_batch).gather(1, action_batch)

我们使用dim=1action_batch 将获得的动作列表替换为列索引,即可获得每个state下该动作的动作价值。

  • 9
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
torch.flatten是PyTorch中的一个函数,用于将输入张量展平为一维。它接受输入张量和可选的起始维度和结束维度参数。在引用\[1\]中的第一个案例中,torch.flatten被用于将形状为(64, 3, 32, 32)的图像张量展平为形状为(64, 3072)的一维张量。在引用\[2\]中的第二个案例中,如果将torch.flatten替换为torch.reshape,结果将是将形状为(64, 3, 32, 32)的图像张量重新调整为形状为(1, 1, 1, 3072)的张量。在引用\[3\]中的案例中,torch.flatten被用于将形状为(3, 3, 3)的张量部分展平为一维,具体来说是从第0维到第1维。 #### 引用[.reference_title] - *1* *2* [深度学习(PyTorch)——flatten函数的用法及其与reshape函数的区别](https://blog.csdn.net/qq_42233059/article/details/126663501)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [PyTorch基础(15)-- torch.flatten()方法](https://blog.csdn.net/dongjinkun/article/details/121479361)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值