pytorch的chunk、gather函数
一、chunk函数
1.函数语法:
torch.chunk
torch.chunk(tensor, chunks, dim=0)
在给定维度(轴)上将输入张量进行分块儿。
参数:
tensor (Tensor) – 待分块的输入张量
chunks (int) – 分块的个数
dim (int) – 沿着此维度进行分块
2.举例1
import torch
torch_1 = torch.randn([2,3]) #插入一个2行3列的张量
torch_c = torch.chunk(torch_1, chunks=2, dim=0)
#dim=0表示按行进行分块
print(torch_1)
print(torch_c) #输出分块结果
输出结果如下:
3.举例2
import torch
torch_1 = torch.randn([2,3]) #插入一个2行3列的张量
torch_c = torch.chunk(torch_1, chunks=2, dim=1)
#dim=1表示按列进行分块
print(torch_1)
print(torch_c) #输出分块结果
按列拆分后,前两列变为一个新的张量,最后一列是一个张量。
二、gather函数
1.gather函数语法
torch.gather(input, dim, index, out=None) → Tensor
沿给定轴dim,将输入索引张量index指定位置的值进行聚合。
对一个3维张量,输出可以定义为:
out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3
这是官网文档给出的解释,gather函数类似于numpy函数中按照索引取数组中的某项或某几项,在gather函数中:
input 为张量对象,需从中取出索引对应的数值
dim 为维度
index 为待取索引
2.举例1
从张量torch_1中,取一维数组
torch_1=torch.tensor([[1,2],[3,4]])
torch_2=torch.gather(torch_1, dim=0, index=torch.tensor([[0, 1]]))
print(torch_1)
print(torch_2)
- 第一行是创建一个2*2的张量,数据如图所示是1,2,3,4
- 第二行按索引取数据,对象为torch_1,dim=0 表示按行,index=[0,1]
- output.shape = index.shape 确定最后输出的output的shape必须与index的相同,这里是12的tensor,那么output必须也是12的tensor,那输出就是torch.tensor([[?,?]]
- 对output所有值的索引,按shape方式排出来,也就是[[(0,0),(0,1)]]
- 还是对output,拿index里的值替换上面dim指定位置,dim=0替换行(第一项),dim=1即替换列(第二项)。此处dim=0,index里的值是[0,1],于是将[[(0,0),(0,1)]]替换为[[(0,0),(1,1)]]
- 按这个索引[[(0,0),(1,1)]]获取torch_1相应位置的值,填进去就好了,得到torch.tensor([[1,4]])
3.举例2
从张量torch_1中,取二维数组
torch_1=torch.tensor([[1,2],[3,4]])
torch_2=torch.gather(torch_1, dim=1, index=torch.tensor([[0,1],[1,0]])
print(torch_1)
print(torch_2)
输出结果为:
- 第一行是创建一个2*2的张量,数据如图所示是1,2,3,4
- 第二行按索引取数据,对象为torch_1,dim=0 表示按行,index=([[0,1],[1,0]])
-output.shape = index.shape 确定最后输出的output的shape必须与index的相同,这里是22的tensor,那么output必须也是22的tensor,那输出就是torch.tensor([[?,?],[?,?]]- 对output所有值的索引,按shape方式排出来,也就是
([[(0,0),(0,1)],
[(1,0),(1,1)]]) - 还是对output,拿index里的值替换上面dim指定位置,dim=0替换行(第一项),dim=1即替换列(第二项)。此处dim=1,即替换列,index=([[0,1],[1,0]])
([[(0,0),(0,1)],
[(1,0),(1,1)]])替换后,结果为
([[(0,0),(0,1)],
[(1,1),(1,0)]]) - 按这个索引([[(0,0),(0,1)],[(1,1),(1,0)]])获取torch_1相应位置的值,填进去就好了,得到torch.tensor(([[1,2],[4,3]]))
- 对output所有值的索引,按shape方式排出来,也就是
总结
对torch.gather
- index的shape等于output的shape,按shape依次写出索引A
- 将索引A对应的dim位置值进行替换,用输入index的值进行替换,替换结果为索引B
- 原tensor的索引B对应的值就是output结果