tf.batch_gather()
Put simply, `tf.batch_gather` looks up values in a tensor by index, batch by batch: for each batch element `i`, it gathers from `params[i]` at the positions given by `indices[i]`.
```python
import tensorflow as tf

tensor_a = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
tensor_b = tf.Variable([[0], [1], [2]], dtype=tf.int32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.batch_gather(tensor_a, tensor_b)))
```
Result:
```
[[1]
 [5]
 [9]]
```
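Since the snippet above needs a TF 1.x session to run, the same per-batch lookup can be sketched in plain NumPy (an illustrative equivalent, not TensorFlow's implementation) with `np.take_along_axis`:

```python
import numpy as np

params = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
indices = np.array([[0], [1], [2]])

# For each row i, pick params[i, indices[i]] -- the batch_gather semantics.
result = np.take_along_axis(params, indices, axis=1)
print(result)  # [[1] [5] [9]]
```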
If an index exceeds the size of the gathered dimension, the out-of-range positions come back filled with zeros (at least in the behavior observed here); this can be used for interpolation-style padding.
```python
import tensorflow as tf

tensor_a = tf.Variable([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]])
print(tensor_a.shape)
tensor_b = tf.Variable([[0, 1, 2, 3, 4]], dtype=tf.int32)
print(tensor_b.shape)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.batch_gather(tensor_a, tensor_b)))
    print(sess.run(tf.batch_gather(tensor_a, tensor_b)).shape)
```
Result:
```
(1, 3, 3)
(1, 5)
[[[1 2 3]
  [4 5 6]
  [7 8 9]
  [0 0 0]
  [0 0 0]]]
(1, 5, 3)
```
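The zero-fill observed above for out-of-range indices can be reproduced in NumPy by clipping the indices and then masking the invalid positions (a sketch of the observed behavior, not TensorFlow's actual kernel):

```python
import numpy as np

params = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]])  # shape (1, 3, 3)
idx = np.array([[0, 1, 2, 3, 4]])                       # 3 and 4 are out of range

# Gather with clipped indices, then zero out the rows whose index was invalid.
valid = idx < params.shape[1]                           # (1, 5) boolean mask
clipped = np.clip(idx, 0, params.shape[1] - 1)
gathered = np.take_along_axis(params, clipped[..., None], axis=1)
gathered = gathered * valid[..., None]
print(gathered.shape)  # (1, 5, 3)
```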
tf.gather()
`tf.gather` takes a one-dimensional index array and extracts the corresponding slices (rows, for a matrix) from a tensor.
```python
import tensorflow as tf

a = tf.Variable([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])
index_a = tf.Variable([0, 2])
b = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
index_b = tf.Variable([2, 4, 6, 8])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.gather(a, index_a)))
    print(sess.run(tf.gather(b, index_b)))
```
Result:
```
[[ 1  2  3  4  5]
 [11 12 13 14 15]]
[3 5 7 9]
```
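`tf.gather` corresponds closely to NumPy's `np.take` (or plain fancy indexing); the same two lookups can be sketched as:

```python
import numpy as np

a = np.array([[ 1,  2,  3,  4,  5],
              [ 6,  7,  8,  9, 10],
              [11, 12, 13, 14, 15]])
b = np.arange(1, 11)

rows = np.take(a, [0, 2], axis=0)   # rows 0 and 2, like tf.gather(a, index_a)
elems = np.take(b, [2, 4, 6, 8])    # elements at positions 2, 4, 6, 8
print(rows)
print(elems)  # [3 5 7 9]
```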
The following is a usage I ran into while reading some code and testing it myself.
```python
import tensorflow as tf
import numpy as np

reducing_list = tf.range(8, dtype=tf.int32)
print(reducing_list)
inserted_value = tf.zeros((1,), dtype=tf.int32)
print(inserted_value)

aaa = np.array([1, 5, 6, 2, 4, 7, 4, 1, 2, 5, 6, 3, 2, 5, 4])
valid_labels_init = tf.Variable(aaa)

for ign_label in range(1):
    reducing_list1 = tf.concat([reducing_list[:ign_label], inserted_value,
                                reducing_list[ign_label:]], axis=0)
    print(reducing_list[:ign_label])
    print(reducing_list[ign_label:])
    print(reducing_list1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(reducing_list))
    print(sess.run(inserted_value))
    print(sess.run(reducing_list[:ign_label]))
    print(sess.run(reducing_list[ign_label:]))
    print(sess.run(reducing_list1))

valid_labels = tf.gather(reducing_list1, valid_labels_init)
print(valid_labels)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(valid_labels))
```
First, a tensor `[0, 1, 2, 3, 4, 5, 6, 7]` and a tensor `[0]` are created; concatenating them into `reducing_list1` gives `[0, 0, 1, 2, 3, 4, 5, 6, 7]`.
`aaa` is a test array of labels numbered 1–7. After `tf.gather(reducing_list1, valid_labels_init)`, we get the same labels renumbered 0–6: every label in `aaa` has effectively been decremented by 1.
The result is as follows:
```
Tensor("range_2:0", shape=(8,), dtype=int32)
Tensor("zeros_2:0", shape=(1,), dtype=int32)
Tensor("strided_slice_10:0", shape=(0,), dtype=int32)
Tensor("strided_slice_11:0", shape=(8,), dtype=int32)
Tensor("concat_4:0", shape=(9,), dtype=int32)
[0 1 2 3 4 5 6 7]
[0]
[]
[0 1 2 3 4 5 6 7]
[0 0 1 2 3 4 5 6 7]
Tensor("GatherV2_2:0", shape=(15,), dtype=int32)
[0 4 5 1 3 6 3 0 1 4 5 2 1 4 3]
```
You can see how the labels change:
Original labels: `[1, 5, 6, 2, 4, 7, 4, 1, 2, 5, 6, 3, 2, 5, 4]`
Labels after gather: `[0 4 5 1 3 6 3 0 1 4 5 2 1 4 3]`
I'm still not sure what this usage is for, or why it is written this way. If anyone knows, please tell me in the comments — thanks!!
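For what it's worth, this pattern looks to me like a lookup-table label remapping (e.g. segmentation preprocessing that shifts labels down by one so that label 1 becomes class 0) — that reading is my assumption, but the arithmetic checks out in NumPy:

```python
import numpy as np

# Rebuild the lookup table: concatenating a 0 in front of range(8)
# gives [0, 0, 1, ..., 7], i.e. table[k] = k - 1 for k >= 1 and table[0] = 0.
table = np.concatenate([[0], np.arange(8)])

labels = np.array([1, 5, 6, 2, 4, 7, 4, 1, 2, 5, 6, 3, 2, 5, 4])
remapped = table[labels]  # same result as tf.gather(table, labels)
print(remapped)  # [0 4 5 1 3 6 3 0 1 4 5 2 1 4 3]
```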
tf.gather_nd()
The content below comes from [this blog post](https://blog.csdn.net/lllxxq141592654/article/details/85400177).
```python
tf.gather_nd(
    params,
    indices,
    name=None
)
```
Extracts slices from `params` according to the layout of `indices` and merges them into one tensor. `indices` is a K-dimensional integer tensor; its last dimension indexes into the leading dimensions of `params`, so a full index picks a scalar and a partial index picks a slice.
Example:
```python
import tensorflow as tf

a = tf.Variable([[ 1,  2,  3,  4,  5],
                 [ 6,  7,  8,  9, 10],
                 [11, 12, 13, 14, 15]])
index_a1 = tf.Variable([[0, 2], [0, 4], [2, 2]])  # a few arbitrary (row, col) pairs
index_a2 = tf.Variable([0, 1])                    # the element at row 0, column 1 -- 2
index_a3 = tf.Variable([[0], [1]])                # [row 0, row 1]
index_a4 = tf.Variable([0])                       # row 0

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.gather_nd(a, index_a1)))
    print(sess.run(tf.gather_nd(a, index_a2)))
    print(sess.run(tf.gather_nd(a, index_a3)))
    print(sess.run(tf.gather_nd(a, index_a4)))
```
Result:
```
[ 3  5 13]
2
[[ 1  2  3  4  5]
 [ 6  7  8  9 10]]
[1 2 3 4 5]
```
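In NumPy terms, `tf.gather_nd` with full `(row, col)` index pairs is plain multi-array indexing, and partial indices select whole rows; a sketch of the first and third cases above:

```python
import numpy as np

a = np.array([[ 1,  2,  3,  4,  5],
              [ 6,  7,  8,  9, 10],
              [11, 12, 13, 14, 15]])

# Last index dimension of size 2 -> each pair picks a scalar a[row, col].
idx = np.array([[0, 2], [0, 4], [2, 2]])
scalars = a[idx[:, 0], idx[:, 1]]
print(scalars)  # [ 3  5 13]

# Last index dimension of size 1 -> each entry picks a whole row.
idx_rows = np.array([[0], [1]])
rows = a[idx_rows[:, 0]]
print(rows)     # rows 0 and 1
```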