tf.batch_gather,tf.gather,tf.gather_nd

tf.batch_gather()

简单来说,batch_gather就是通过索引来获取数组的值

import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.batch_gather(tensor_a,tensor_b)))

结果:

[[1]
 [5]
 [9]]

如果提取索引超过数据集中的维度,则超出维度的数据全部添0,这常用于插值

import tensorflow as tf

tensor_a = tf.Variable([[[1,2,3],[4,5,6],[7,8,9]]])
print(tensor_a.shape)
tensor_b = tf.Variable([[0,1,2,3,4]],dtype=tf.int32)
print(tensor_b.shape)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.batch_gather(tensor_a,tensor_b)))
    print(sess.run(tf.batch_gather(tensor_a,tensor_b)).shape)

结果:

(1, 3, 3)
(1, 5)

[[[1 2 3]
  [4 5 6]
  [7 8 9]
  [0 0 0]
  [0 0 0]]]
  
(1, 5, 3)

tf.gather()

tf.gather:用一个一维的索引数组,将张量中对应索引的向量提取出来

import tensorflow as tf
 
a = tf.Variable([[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]])
index_a = tf.Variable([0,2])
 
b = tf.Variable([1,2,3,4,5,6,7,8,9,10])
index_b = tf.Variable([2,4,6,8])
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.gather(a, index_a)))
    print(sess.run(tf.gather(b, index_b)))

结果

 [[ 1  2  3  4  5]
 [11 12 13 14 15]]
 
 [3 5 7 9]

以下是我在看代码时,测试代码发现的一个用法

import tensorflow as tf
import numpy as np

reducing_list = tf.range(8, dtype=tf.int32)
print(reducing_list)
inserted_value = tf.zeros((1,), dtype=tf.int32)
print(inserted_value)
aaa =np.array([1,5,6,2,4,7,4,1,2,5,6,3,2,5,4])
valid_labels_init = tf.Variable(aaa)

for ign_label in range(1):
    reducing_list1 = tf.concat([reducing_list[:ign_label], inserted_value, reducing_list[ign_label:]],axis=0)
    print(reducing_list[:ign_label])
    print(reducing_list[ign_label:])
    print(reducing_list1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(reducing_list))
    print(sess.run(inserted_value))
    print(sess.run(reducing_list[:ign_label]))
    print(sess.run(reducing_list[ign_label:]))
    print(sess.run(reducing_list1))

valid_labels = tf.gather(reducing_list1, valid_labels_init)
print(valid_labels)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(valid_labels))

首先生成一个[0,1,2,3,4,5,6,7]的tensor和一个[0]的Tensor,将这两个Tensor拼接为reducing_list1,可得到[0,0,1,2,3,4,5,6,7]。
aaa是测试设的一个含1-7编号的标签数据,tf.gather(reducing_list1,valid_labels_init)后,获得了一个0-6编号的标签数据aaa,相当于aaa中所有标签全部减1。

结果如下:

Tensor("range_2:0", shape=(8,), dtype=int32)
Tensor("zeros_2:0", shape=(1,), dtype=int32)

Tensor("strided_slice_10:0", shape=(0,), dtype=int32)
Tensor("strided_slice_11:0", shape=(8,), dtype=int32)

Tensor("concat_4:0", shape=(9,), dtype=int32)
[0 1 2 3 4 5 6 7]
[0]
[]
[0 1 2 3 4 5 6 7]
[0 0 1 2 3 4 5 6 7]
Tensor("GatherV2_2:0", shape=(15,), dtype=int32)

[0 4 5 1 3 6 3 0 1 4 5 2 1 4 3]

可以看到标签变化

原始便签:[1,5,6,2,4,7,4,1,2,5,6,3,2,5,4]
gather后的便签:[0 4 5 1 3 6 3 0 1 4 5 2 1 4 3]

我现在还不清楚这种用法有什么含义,以及为什么这样的,如果有知道的好兄弟麻烦在评论处告知我,感谢!!

tf.gather_nd()

内容来自于[博客](https://blog.csdn.net/lllxxq141592654/article/details/85400177)

tf.gather_nd(
    params,
    indices,
    name=None
)

按照indices的格式从params中抽取切片(合并为一个Tensor),indices是一个K维整数Tensor

例子

import tensorflow as tf

a = tf.Variable([[1, 2, 3, 4, 5],
                 [6, 7, 8, 9, 10],
                 [11, 12, 13, 14, 15]])
index_a1 = tf.Variable([[0, 2], [0, 4], [2, 2]])  # 随便选几个
index_a2 = tf.Variable([0, 1])  # 0行1列的元素——2
index_a3 = tf.Variable([[0], [1]])  # [第0行,第1行]
index_a4 = tf.Variable([0])  # 第0行


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.gather_nd(a, index_a1)))
    print(sess.run(tf.gather_nd(a, index_a2)))
    print(sess.run(tf.gather_nd(a, index_a3)))
    print(sess.run(tf.gather_nd(a, index_a4)))
[ 3  5 13]
2
[[ 1  2  3  4  5]
 [ 6  7  8  9 10]]
[1 2 3 4 5]
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值