1.tf.where()
第一种用法:
- where(condition)的用法
where(condition, x=None, y=None, name=None)
condition是bool型值,True/False
返回值,是condition中元素为True对应的索引
例如:
import tensorflow as tf
a = [[1,2,3],[4,5,6]]
b = [[1,0,3],[1,5,1]]
condition1 = [[True,False,False],
[False,True,True]]
condition2 = [[True,False,False],
[False,True,False]]
with tf.Session() as sess:
print(sess.run(tf.where(condition1)))
print(sess.run(tf.where(condition2)))
结果1:
[[0 0]
[1 1]
[1 2]]
结果2:
[[0 0]
[1 1]]
第二种用法:
where(condition, x=None, y=None, name=None)
condition是bool型值,True/False, x, y 相同维度,
返回值是对应元素,condition中元素为True的元素替换为x中的元素,为False的元素替换为y中对应元素。
x只负责对应替换True的元素,y只负责对应替换False的元素,x,y各有分工
由于是替换,返回值的维度,和condition,x , y都是相等的。
import tensorflow as tf
x = [[1,2,3],[4,5,6]]
y = [[7,8,9],[10,11,12]]
condition3 = [[True,False,False],
[False,True,True]]
condition4 = [[True,False,False],
[True,True,False]]
with tf.Session() as sess:
print(sess.run(tf.where(condition3,x,y)))
print(sess.run(tf.where(condition4,x,y)))
结果:
[[ 1 8 9]
[10 5 6]]
[[ 1 8 9]
[ 4 5 12]]
第三种用法:
tf.where(tf.greater(A, B), a, b)
tf.greater(a,b)
功能:通过比较a、b两个值的大小来输出True和False。
where会先判断第一项是否为true,如果为true则返回a;否则返回b;而greater则是比较A是否大于B,是的话返回true;否则返回false
2.tf.gather()
我们知道,ndarray和list都可以直接通过索引进行切片,但tensor却不行。不过TensorFlow提供了多个函数来进行张量切片,tf.gather()就是其中一种,其调用形式如下:
tf.gather(params, indices, validate_indices=None, name=None, axis=0)
参数:
- params:要进行切片的ndarray或list或tensor等
- indices:索引向量,其类型可以是ndarray、list、tensor等
- axis : 对哪个轴进行切片
函数功能:
从’params’的’axis’维根据’indices’的参数值获取切片。就是在axis维根据indices取某些值,最终得到新的tensor
示例:
1. params 的维数为1
import tensorflow as tf
import numpy as np
# params = np.random.randint(1, 10, 5)
# params = [2, 3, 4, 5, 6, 7]
params = tf.constant([2, 3, 4, 5, 6, 7])
# indices = np.array([2, 1, 4, 2])
# indices = [2, 1, 4, 2]
indices = tf.constant([2, 1, 4, 2])
tensor1 = tf.gather(params, indices)
with tf.Session() as sess:
# print(params)
print(sess.run(params))
print(sess.run(tensor1))
结果:
[2 3 4 5 6 7]
[4 3 6 4]
#分析:根据indices逐一取出params中对应索引的元素,并组成新的张量。
2. params 的维数为2
import tensorflow as tf
import numpy as np
params = np.random.randint(1, 10, (4, 5))
indices = tf.constant([2, 1, 0, 2])
tensor0 = tf.gather(params, indices, axis=0)
tensor1 = tf.gather(params, indices, axis=1)
with tf.Session() as sess:
print('params =', params)
print('tensor0 =', sess.run(tensor0))
print('tensor1 =', sess.run(tensor1))
结果:
params = [[5 1 4 7 2]
[1 8 9 1 7]
[2 1 8 7 2]
[8 9 5 8 7]]
tensor0 = [[2 1 8 7 2]
[1 8 9 1 7]
[5 1 4 7 2]
[2 1 8 7 2]]
tensor1 = [[4 1 5 4]
[9 8 1 9]
[8 1 2 8]
[5 9 8 5]]
对于二维params,
当indices是标量且是张量时,得到的结果不会降维;
当indices是标量且是ndarray时,得到的结果会降维。
import tensorflow as tf
import numpy as np
params = np.random.randint(1, 10, (3, 4))
indices1 = tf.constant([2])
indices2 = 2
tensor1 = tf.gather(params, indices1, axis=0)
tensor2 = tf.gather(params, indices2, axis=0)
with tf.Session() as sess:
print('params =', params)
print('tensor1 =', sess.run(tensor1))
print('tensor2 =', sess.run(tensor2))
结果:
params = [[9 2 1 7]
[7 8 2 3]
[9 7 2 9]]
tensor1 = [[9 7 2 9]]
tensor2 = [9 7 2 9]
3.tf.squeeze()
tf.squeeze(input, axis=None, name=None, squeeze_dims=None)
该函数返回一个张量,这个张量是将原始input中所有维度为1的那些维都删掉的结果。
axis可以用来指定要删掉的为1的维度,此处要注意指定的维度必须确保其是1,否则会报错。
import tensorflow as tf
input_tensor = tf.ones((2, 1, 1, 3, 2))
new_tensor1 = tf.squeeze(input_tensor)
new_tensor2 = tf.squeeze(input_tensor, [1])
with tf.Session() as sess:
print(sess.run(tf.shape(input_tensor)))
print(sess.run(tf.shape(new_tensor1)))
print(sess.run(tf.shape(new_tensor2)))
结果:
[2 1 1 3 2]
[2 3 2]
[2 1 3 2]
附加:
tf.less()、tf.greater()、tf.equal()等比较函数
这几个函数用于逐元素比较两个张量的大小,并返回比较结果(True or False)构成的布尔型张量。下面以tf.less()为例:
tf.less(x, y, name=None)
tf.less()返回了两个张量各元素比较(x<y)得到的真假值组成的张量。
提示:
- tf.less()支持broadcast机制;
- tf.less(x, y)中的 x 和 y 可以是tensor、ndarray、list等。
x = tf.constant([[1, 2, 3], [4, 5, 6]])
y1 = tf.constant([[2, 1, 2], [2, 6, 7]])
y2 = tf.constant([3, 6, 9])
y3 = tf.constant([3])
with tf.Session() as sess:
print(sess.run(tf.less(x, y1)))
print(sess.run(tf.less(x, y2)))
print(sess.run(tf.less(x, y3)))
结果:
[[ True False False]
[False True True]]
[[ True True True]
[False True True]]
[[ True True False]
[False False False]]
总结:
- tf.less(x, y) —— x < y 为True
- tf.equal(x, y) —— x == y 为True
- tf.greater(x, y) —— x > y 为True
- tf.greater_equal(x, y) —— x >= y 为True
- tf.less_equal(x, y) —— x <= y 为True