tf.gather
numpy 支持用 ndarray 索引:
import numpy as np
arr = np.arange(9).reshape(3, 3)
idx = np.array([0, 2])
print(arr[idx])
但 tensorflow 中非 scalar 的 tensor 不可以直接用作下标:
import tensorflow as tf
arr = tf.reshape(tf.range(9), [3, 3])
idx = tf.constant([0, 2])
one = tf.constant(1)
with tf.Session() as sess:
print(sess.run(arr[one])) # 可以
print(sess.run(arr[0:1])) # 可以
print(sess.run(arr[idx])) # 报错
要实现类似的功能,用 tf.gather
:
import tensorflow as tf
arr = tf.reshape(tf.range(9), [3, 3])
idx = tf.constant([0, 2])
with tf.Session() as sess:
print(sess.run(tf.gather(arr, idx)))
tf.gather_nd
这次的目标是:给出矩阵
A
n
×
m
A_{n\times m}
An×m 和索引向量
b
n
×
1
b_{n\times1}
bn×1,取各 A[i][b[i]],即 A 的每行都取一个元素,下标由 b[i] 决定。用到tf.gather_nd
。
tf.gather_nd 用元素的「坐标」选元素,即传入的第二个参数indices
是要选的那些元素的坐标的序列。例如对于上述的目标,indices 就是各[i, b[i]]
组成的序列。b[i] 已经有,只要补上行坐标 i 就行。
Example
- 这里同时也实现了 tensorflow 的 tensor 随机索引,即生成随机索引向量 b,用于索引张量 A 的分量。
import tensorflow as tf
import numpy as np
n = 3
m = 4
# 备选数组
A = tf.constant(np.arange(n * m).reshape(n, m))
# 随机生成列 id
b = tf.random_uniform([n, 1],
minval=0, maxval=m, # 列 id 范围:[0, m)
dtype=tf.int32)
# 补上行 id
row_id = tf.range(n, dtype="int32")[:, None] # 形状:[n, 1]
#print(row_id.shape.as_list())
#print(b.shape.as_list())
# 拼在一起组成完整坐标
idx = tf.concat([row_id, b], axis=1)
# 选元素
elem = tf.gather_nd(A, idx)
with tf.Session() as sess:
A, b, r, i, e = sess.run([A, b, row_id, idx, elem])
print("A:\n", A)
print("b:\n", b)
print("row id:\n", r)
print("indices:\n", i)
print("elem:\n", e)
结果:
A:
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
b:
[[0]
[1]
[1]]
row id:
[[0]
[1]
[2]]
indices:
[[0 0]
[1 1]
[2 1]]
elem:
[0 5 9]
(0,5,9)
即 ( A[0][0], A[1][1], A[2][1] )
tf.batch_gather
用 tf.batch_gather
和 tf.argsort
/tf.top_k
实现矩阵分行排序。场景是:用 A 的行数据 argsort 得出的 indices 来对 B 行数据排序。
import tensorflow as tf
a = tf.constant([
[1, 0, 3, 2, 5],
[4, 7, 9, 8, 6]
])
b = tf.reshape(tf.range(10, 20), [2, 5])
# k_idx = tf.argsort(ham) # tf 1.12 无 `argsort`…用 top_k 代替
k_val, k_idx = tf.math.top_k(- a, a.shape[1]) # minus for ascending
b_sort = tf.batch_gather(b, k_idx)
with tf.Session() as sess:
a, b, k_idx, b_sort = sess.run([a, b, k_idx, b_sort])
print("a:\n", a)
print("b:\n", b)
print("k_idx:\n", k_idx)
print("b_sort:\n", b_sort)
结果:
a:
[[1 0 3 2 5]
[4 7 9 8 6]]
b:
[[10 11 12 13 14]
[15 16 17 18 19]]
k_idx:
[[1 0 3 2 4]
[0 4 1 3 2]]
b_sort:
[[11 10 13 12 14]
[15 19 16 18 17]]
tf.scatter_nd_update / tf.tensor_scatter_nd_update
将高级索引用于左值。注意 tensorflow 需要对 index 升维,详见代码。
in numpy
这个功能对应的 numpy 示例
import numpy as np
a = np.arange(12).reshape(3, 4)
print("before:\n", a)
idx = np.array([0, 2]) # numpy 的 index 不 需要升维
val = np.array([
[11, 12, 13, 14],
[15, 16, 17, 18]
])
a[idx] = val # 左值用高级索引
print("after:\n", a)
输出:
before:
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
after:
[[11 12 13 14]
[ 4 5 6 7]
[15 16 17 18]]
in tensorflow 1.12
用 tf.scatter_nd_update
,见 [7]。
import tensorflow as tf
sess = tf.Session()
a = tf.Variable(tf.reshape(tf.range(12), [3, 4]))
sess.run(tf.global_variables_initializer())
print("before:\n", sess.run(a))
idx = tf.constant([[0], [2]]) # index 升维
val = tf.constant([
[11, 12, 13, 14],
[15, 16, 17, 18]
])
update = a.assign(tf.scatter_nd_update(a, idx, val)) # 左值用高级索引
print("after:\n", sess.run([a, update])[0])
sess.close()
输出:
before:
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
after:
[[11 12 13 14]
[ 4 5 6 7]
[15 16 17 18]]
in tensorflow 2.1
用 tf.tensor_scatter_nd_update
,见 [8]。
import tensorflow as tf
a = tf.Variable(tf.reshape(tf.range(12), [3, 4]))
print("before:\n", a)
idx = tf.constant([[0], [2]]) # index 升维
val = tf.constant([
[11, 12, 13, 14],
[15, 16, 17, 18]
])
a.assign(tf.tensor_scatter_nd_update(a, idx, val)) # 左值用高级索引
print("after:\n", a)
输出:
before:
<tf.Variable 'Variable:0' shape=(3, 4) dtype=int32, numpy=
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)>
after:
<tf.Variable 'Variable:0' shape=(3, 4) dtype=int32, numpy=
array([[11, 12, 13, 14],
[ 4, 5, 6, 7],
[15, 16, 17, 18]], dtype=int32)>
References
- TensorFlow - numpy-like tensor indexing
- Generalize slicing and slice assignment ops (including gather and scatter) #206
- TF 中的 indexing 和 slicing
- tf.gather和tf.gather_nd的详细用法–tensorflow通过索引取tensor里的数据
- 从Tensorflow中从另一个中挑选随机张量
- tf.gather tf.gather_nd 和 tf.batch_gather 使用方法
- Tensorflow深度学习之三十二: tf.scatter_nd_update
- tf.tensor_scatter_nd_update