在 Tensorflow 和 PyTorch 中,
gather
函数用于按照指定的索引从输入张量中聚合元素并返回新的张量。
具体来说,该函数的作用和用法如下:
TensorFlow 中的 tf.gather()
函数:
tf.gather(params, indices, axis=None, batch_dims=0, name=None)
其中,params
表示输入的张量;indices
表示需要聚合的索引,可以是常量列表或张量;axis
是指定聚合维度的整数,如果不传递,则默认为 0;batch_dims
是指定批次维度数的整数,通常在将多个样本进行聚合时使用;name
是可选的操作名称。
例如,对于张量 params = [[1, 2], [3, 4], [5, 6], [7, 8]]
,要求按照索引 [0, 2, 3]
在第一维度上进行聚合,则可以使用以下代码:
import tensorflow as tf
params = tf.constant([[1, 2], [3, 4], [5, 6], [7, 8]])
indices = [0, 2, 3]
output = tf.gather(params, indices, axis=0)
print(output)
# Output: [[1 2]
# [5 6]
# [7 8]]
PyTorch 中的 torch.gather()
函数:
torch.gather(input, dim, index, out=None)
其中,input
表示输入的张量;dim
表示需要聚合的维度;index
表示需要聚合的索引,可以是常量列表或张量;out
是可选的输出张量。
例如,对于张量 input = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
,要求按照索引 torch.tensor([[0], [2], [3]])
在第一维度上进行聚合,则可以使用以下代码:
import torch
input = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
indices = torch.tensor([[0], [2], [3]])
output = torch.gather(input, 0, indices)
print(output)
# Output: tensor([[1],
# [5],
# [7]])
——————————————————————————————
不同:
tf.gather的索引为 [0, 2, 3]时,可以
在第一维度上进行聚合,但
如果torch.gather的索引也设置为[0, 2, 3],则会报错:
RuntimeError: Index tensor must have the same number of dimensions as input tensor
而若torch.gather的索引设置为[[0], [2], [3]]
就会得到像上面的输出:
# Output: tensor([[1],
# [5],
# [7]])
为达到同样的效果,我们先用
torch.unsqueeze函数将tensor添加一个维度从[0, 2, 3]变为[[0], [2], [3]], 再用repeat函数对索引进行更改:
import torch
input = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
indices = torch.tensor([0, 2, 3])
indices = torch.unsqueeze(indices, dim=1)
indices = indices.repeat(1, input.size(1))
output = torch.gather(input, 0, indices)
print(output)
# Output: tensor([[1, 2],
# [5, 6],
# [7, 8]])
repeat函数表示对这个索引进行重复,第一个参数为1,表示行不变,第二个参数为
input.size(1),表示列重复input.size(1)遍,即 两遍,得到的indices为: tensor([[0, 0], [2, 2], [3, 3]])
当遇到 多维张量 时,我们仍可这样操作
-----------------------------------------------------------------------
by the way,我们也可以对维度,也就是dim/axis进行操作,
我们可以对比一下
input = torch.tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
x = index = torch.tensor([[1], [0]])
index=index.repeat(1, input.size(1))
c = torch.gather(input,dim=0,index=index)
d = torch.gather(input,dim=1,index=x)
#tensor([[4, 5, 6],
# [1, 2, 3]])
#tensor([[2],
# [4]])
input = torch.tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
x = index = torch.tensor([[1], [0]])
index=index.repeat(1, input.size(1))
c = torch.gather(input,dim=0,index=index)
d = torch.gather(input,dim=1,index=index)
#tensor([[4, 5, 6],
# [1, 2, 3]])
#tensor([[2, 2, 2],
# [4, 4, 4]])
by the way way:
tf.gather_nd可以替换tf.gather,但是在用gather_nd会引入更多额外参数,对4-d tensor,假设我们想用tf.gather_nd替换tf.gather,就要提取出对应轴的元素,此时的indices就要把想要的元素对应索引组成一个矩阵就可以了。
那不如来思考一个形状为[2,3,4,5] 的parmas,
如何把tf.gather(params,axis = 3,indices=[0,2])用tf.gather_nd来输出同样的结果。
import tensorflow as tf
import numpy as np
a = tf.Variable(tf.random_uniform(shape=(2, 3, 4, 5), name="v"))
nnzs = [0,2]
nnzs = np.asarray(nnzs,"int32")
#initi= np.asarray(initi,"int32")
initi =np.zeros((2,3,4,nnzs.size,4),dtype=np.int)
print(initi.shape)
for i in range(initi.shape[0]):
for j in range(initi.shape[1]):
for k in range (initi.shape[2]):
for l in range(nnzs.size):
initi[i][j][k][l] =[i,j,k,nnzs[l]]
indices = tf.Variable(initial_value = initi, name="indices")
c = tf.gather_nd(a,initi)
d = tf.gather(a,indices= nnzs,axis =3)
print(c.get_shape())
print(d.get_shape())
e = tf.equal(c,d)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run([e]))
最后输出一个全为true的array。