pytorch的chunk、gather函数

pytorch的chunk、gather函数

一、chunk函数

1.函数语法:

torch.chunk
torch.chunk(tensor, chunks, dim=0)
在给定维度()上将输入张量进行分块儿。 
参数:
tensor (Tensor) – 待分块的输入张量
chunks (int) – 分块的个数
dim (int) – 沿着此维度进行分块

2.举例1

import torch
torch_1 = torch.randn([2,3])  #插入一个2行3列的张量
torch_c = torch.chunk(torch_1, chunks=2, dim=0) 
#dim=0表示按行进行分块
print(torch_1)
print(torch_c)  #输出分块结果

输出结果如下:
在这里插入图片描述
3.举例2

import torch
torch_1 = torch.randn([2,3])  #插入一个2行3列的张量
torch_c = torch.chunk(torch_1, chunks=2, dim=1) 
#dim=1表示按列进行分块
print(torch_1)
print(torch_c)  #输出分块结果

在这里插入图片描述
按列拆分后,前两列变为一个新的张量,最后一列是一个张量。

二、gather函数

1.gather函数语法

torch.gather(input, dim, index, out=None) → Tensor
沿给定轴dim,将输入索引张量index指定位置的值进行聚合。
对一个3维张量,输出可以定义为:
out[i][j][k] = tensor[index[i][j][k]][j][k]  # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k]  # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]]  # dim=3

这是官网文档给出的解释,gather函数类似于numpy函数中按照索引取数组中的某项或某几项,在gather函数中:

input 为张量对象,需从中取出索引对应的数值
dim   为维度
index 为待取索引

2.举例1

从张量torch_1中,取一维数组

torch_1=torch.tensor([[1,2],[3,4]])
torch_2=torch.gather(torch_1, dim=0, index=torch.tensor([[0, 1]]))
print(torch_1)
print(torch_2)

在这里插入图片描述

  • 第一行是创建一个2*2的张量,数据如图所示是1,2,3,4
  • 第二行按索引取数据,对象为torch_1,dim=0 表示按行,index=[0,1]
    • output.shape = index.shape 确定最后输出的output的shape必须与index的相同,这里是12的tensor,那么output必须也是12的tensor,那输出就是torch.tensor([[?,?]]
    • 对output所有值的索引,按shape方式排出来,也就是[[(0,0),(0,1)]]
    • 还是对output,拿index里的值替换上面dim指定位置,dim=0替换行(第一项),dim=1即替换列(第二项)。此处dim=0,index里的值是[0,1],于是将[[(0,0),(0,1)]]替换为[[(0,0),(1,1)]]
    • 按这个索引[[(0,0),(1,1)]]获取torch_1相应位置的值,填进去就好了,得到torch.tensor([[1,4]])

3.举例2

从张量torch_1中,取二维数组

torch_1=torch.tensor([[1,2],[3,4]])
torch_2=torch.gather(torch_1, dim=1, index=torch.tensor([[0,1],[1,0]])
print(torch_1)
print(torch_2)

输出结果为:
在这里插入图片描述

  • 第一行是创建一个2*2的张量,数据如图所示是1,2,3,4
  • 第二行按索引取数据,对象为torch_1,dim=0 表示按行,index=([[0,1],[1,0]])
    -output.shape = index.shape 确定最后输出的output的shape必须与index的相同,这里是22的tensor,那么output必须也是22的tensor,那输出就是torch.tensor([[?,?],[?,?]]
    • 对output所有值的索引,按shape方式排出来,也就是
      ([[(0,0),(0,1)],
      [(1,0),(1,1)]])
    • 还是对output,拿index里的值替换上面dim指定位置,dim=0替换行(第一项),dim=1即替换列(第二项)。此处dim=1,即替换列,index=([[0,1],[1,0]])
      ([[(0,0),(0,1)],
      [(1,0),(1,1)]])替换后,结果为
      ([[(0,0),(0,1)],
      [(1,1),(1,0)]])
    • 按这个索引([[(0,0),(0,1)],[(1,1),(1,0)]])获取torch_1相应位置的值,填进去就好了,得到torch.tensor(([[1,2],[4,3]]))

总结

对torch.gather

  • index的shape等于output的shape,按shape依次写出索引A
  • 将索引A对应的dim位置值进行替换,用输入index的值进行替换,替换结果为索引B
  • 原tensor的索引B对应的值就是output结果
  • 28
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值