TensorFlow常用函数tf.where()、tf.gather()、tf.squeeze()详解!!

1.tf.where()

第一种用法:

  1. where(condition)的用法
  • where(condition, x=None, y=None, name=None)

 condition是bool型值,True/False
返回值,是condition中元素为True对应的索引
例如:

import tensorflow as tf
a = [[1,2,3],[4,5,6]]
b = [[1,0,3],[1,5,1]]
condition1 = [[True,False,False],
             [False,True,True]]
condition2 = [[True,False,False],
             [False,True,False]]
with tf.Session() as sess:
    print(sess.run(tf.where(condition1)))
    print(sess.run(tf.where(condition2)))

结果1[[0 0]
 [1 1]
 [1 2]]
结果2[[0 0]
 [1 1]]

第二种用法:

  • where(condition, x=None, y=None, name=None)
    condition是bool型值,True/False, x, y 相同维度,
    返回值是对应元素,condition中元素为True的元素替换为x中的元素,为False的元素替换为y中对应元素。
    x只负责对应替换True的元素,y只负责对应替换False的元素,x,y各有分工
    由于是替换,返回值的维度,和condition,x , y都是相等的。
import tensorflow as tf
x = [[1,2,3],[4,5,6]]
y = [[7,8,9],[10,11,12]]
condition3 = [[True,False,False],
             [False,True,True]]
condition4 = [[True,False,False],
             [True,True,False]]
with tf.Session() as sess:
    print(sess.run(tf.where(condition3,x,y)))
    print(sess.run(tf.where(condition4,x,y)))  
结果:    
[[ 1  8  9]
[10  5  6]]
[[ 1  8  9]
[ 4  5 12]]

第三种用法:

  • tf.where(tf.greater(A, B), a, b)
    tf.greater(a,b)
    功能:通过比较a、b两个值的大小来输出True和False。
    where会先判断第一项是否为true,如果为true则返回a;否则返回b;而greater则是比较A是否大于B,是的话返回true;否则返回false

2.tf.gather()

 我们知道,ndarray和list都可以直接通过索引进行切片,但tensor却不行。不过TensorFlow提供了多个函数来进行张量切片,tf.gather()就是其中一种,其调用形式如下:

tf.gather(params, indices, validate_indices=None, name=None, axis=0)

参数:

  • params:要进行切片的ndarray或list或tensor等
  • indices:索引向量,其类型可以是ndarray、list、tensor等
  • axis : 对哪个轴进行切片
    函数功能:
    从’params’的’axis’维根据’indices’的参数值获取切片。就是在axis维根据indices取某些值,最终得到新的tensor
    示例:
    1. params 的维数为1
import tensorflow as tf
import numpy as np
# params = np.random.randint(1, 10, 5)
# params = [2, 3, 4, 5, 6, 7]
params = tf.constant([2, 3, 4, 5, 6, 7])
# indices = np.array([2, 1, 4, 2])
# indices = [2, 1, 4, 2]
indices = tf.constant([2, 1, 4, 2])
tensor1 = tf.gather(params, indices)
with tf.Session() as sess:
    # print(params)
    print(sess.run(params))
    print(sess.run(tensor1))
结果:
[2 3 4 5 6 7]
[4 3 6 4]

#分析:根据indices逐一取出params中对应索引的元素,并组成新的张量。

2. params 的维数为2

import tensorflow as tf
import numpy as np

params = np.random.randint(1, 10, (4, 5))
indices = tf.constant([2, 1, 0, 2])
tensor0 = tf.gather(params, indices, axis=0)
tensor1 = tf.gather(params, indices, axis=1)
with tf.Session() as sess:
    print('params =', params)
    print('tensor0 =', sess.run(tensor0))
    print('tensor1 =', sess.run(tensor1))
结果:
params = [[5 1 4 7 2]
 		  [1 8 9 1 7]
 		  [2 1 8 7 2]
 		  [8 9 5 8 7]]
tensor0 = [[2 1 8 7 2]
 		   [1 8 9 1 7]
		   [5 1 4 7 2]
		   [2 1 8 7 2]]
tensor1 = [[4 1 5 4]
 		   [9 8 1 9]
 	 	   [8 1 2 8]
 		   [5 9 8 5]]

 对于二维params,
 当indices是标量且是张量时,得到的结果不会降维;
 当indices是标量且是ndarray时,得到的结果会降维。

import tensorflow as tf
import numpy as np

params = np.random.randint(1, 10, (3, 4))
indices1 = tf.constant([2])
indices2 = 2
tensor1 = tf.gather(params, indices1, axis=0)
tensor2 = tf.gather(params, indices2, axis=0)
with tf.Session() as sess:
    print('params =', params)
    print('tensor1 =', sess.run(tensor1))
    print('tensor2 =', sess.run(tensor2))
结果:
params = [[9 2 1 7]
	      [7 8 2 3]
 		  [9 7 2 9]]
tensor1 = [[9 7 2 9]]
tensor2 = [9 7 2 9]

3.tf.squeeze()

tf.squeeze(input, axis=None, name=None, squeeze_dims=None)
 该函数返回一个张量,这个张量是将原始input中所有维度为1的那些维都删掉的结果。
 axis可以用来指定要删掉的为1的维度,此处要注意指定的维度必须确保其是1,否则会报错。

import tensorflow as tf
input_tensor = tf.ones((2, 1, 1, 3, 2))
new_tensor1 = tf.squeeze(input_tensor)
new_tensor2 = tf.squeeze(input_tensor, [1])
with tf.Session() as sess:
    print(sess.run(tf.shape(input_tensor)))
    print(sess.run(tf.shape(new_tensor1)))
    print(sess.run(tf.shape(new_tensor2)))
结果:
[2 1 1 3 2]
[2 3 2]
[2 1 3 2]

附加:

tf.less()、tf.greater()、tf.equal()等比较函数

 这几个函数用于逐元素比较两个张量的大小,并返回比较结果(True or False)构成的布尔型张量。下面以tf.less()为例:
tf.less(x, y, name=None)
tf.less()返回了两个张量各元素比较(x<y)得到的真假值组成的张量。
提示:

  • tf.less()支持broadcast机制;
  • tf.less(x, y)中的 x 和 y 可以是tensor、ndarray、list等。
x = tf.constant([[1, 2, 3], [4, 5, 6]])
y1 = tf.constant([[2, 1, 2], [2, 6, 7]])
y2 = tf.constant([3, 6, 9])
y3 = tf.constant([3])
with tf.Session() as sess:
    print(sess.run(tf.less(x, y1)))
    print(sess.run(tf.less(x, y2)))
    print(sess.run(tf.less(x, y3)))
结果:
[[ True False False]
 [False  True  True]]
[[ True  True  True]
 [False  True  True]]
[[ True  True False]
 [False False False]]

总结:

  • tf.less(x, y) —— x < y 为True
  • tf.equal(x, y) —— x == y 为True
  • tf.greater(x, y) —— x > y 为True
  • tf.greater_equal(x, y) —— x >= y 为True
  • tf.less_equal(x, y) —— x <= y 为True
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值