在Tensorflow(TF)中,常常会看到“reduce_"系列的东西,比如reduce_sum,reduce_mean…,刚刚开始我觉得,求和就sum就行,为什么加一个reduce_前缀?
后来我意识到,每一次求和或者求平均值,其实都是自动的对Tensor进行了降维,例如,对向量[1,2,3,4]求和,得到一个标量10。
刚刚开始接触reduce系列,通常需要花一些时间在脑海中可视化哪一个维度被减掉了,但是在熟悉了tensor的 shape/dimensions/indexing之后,很多debug会节省很多时间。
其实TF的sum和mean与numpy,matlab都类似,index=0表示按列求和或者取平均,index=1表示按照行求和或者取平均,只是TF中处理的多是4维或者5维[batch_size, time_steps, width, height, channels]张量,不那么容易理解,下面以一个三维Tensor为例子来操作。
# Let's initialize the tensor.
In [3]: x = tf.constant([[[1,2,3,4,5], [4,5,6,7,8]],
[[2,4,6,8,10],[3,6,9,12,15]],
[[8,8,8,8,8], [9,9,9,9,9]]])
In [9]: sess = tf.InteractiveSession()
# Let's see how it looks.
In [10]: x.eval()
Out[10]:
array([[[ 1, 2, 3, 4, 5],
[ 4, 5, 6, 7, 8]],
[[ 2, 4, 6, 8, 10],
[ 3, 6, 9, 12, 15]],
[[ 8, 8, 8, 8, 8],
[ 9, 9, 9, 9, 9]]], dtype=int32)
这里顺便说一下tf.InteractiveSession和tf.Session的区别,经常做大项目的同学知道,tf.Session一般都是在整个网络的计算图构建完了之后,再来run初始化变量和feed数据。但是tf.InteractiveSession不同,它支持先开sess,然后加入计算图,也就是说,如果是做小实验,熟悉Tensorflow的各个函数,用tf.InteractiveSession和eval()函数是合适的。
tf.reduce_max(x,0) 输出是什么呢?
tf.reduce_max(x, 0) 减掉的是第0维,所以这个(3,2,5)的tensor中的3会消失。
In [11]: tf.reduce_max(x,0)
Out[11]: <tf.Tensor 'Max:0' shape=(2, 5) dtype=int32>
In [12]: tf.reduce_max(x,0).eval()
Out[12]:
array([[ 8, 8, 8, 8, 10],
[ 9, 9, 9, 12, 15]], dtype=int32)
对于第一维是一样的,结果的shape是(3,5)
In [14]: tf.reduce_max(x,1)
Out[14]: <tf.Tensor 'Max_3:0' shape=(3, 5) dtype=int32>
In [15]: tf.reduce_max(x,1).eval()
Out[15]:
array([[ 4, 5, 6, 7, 8],
[ 3, 6, 9, 12, 15],
[ 9, 9, 9, 9, 9]], dtype=int32)
在一些实际代码中,可能会看到reduce中不是减掉单个维度而是减掉几个维度,比如tf.reduce_max(x,(1,2))意味着我们想要每个batch中的最大值(batch_size通常是dimension 0),这种情况下坍缩的就是维度1和2.
In [16]: tf.reduce_max(x,(1,2)).eval()
Out[16]: array([ 8, 15, 9], dtype=int32)
In [17]: tf.reduce_max(x,(1,2))
Out[17]: <tf.Tensor 'Max_11:0' shape=(3,) dtype=int32>
你会发现这改变了最初tensor的shape,怎么办?
加上那句keepdims=True.
In [22]: tf.reduce_max(x,(1,2), keepdims=True).eval()
Out[22]:
array([[[ 8]],
[[15]],
[[ 9]]], dtype=int32)
In [23]: tf.reduce_max(x,(1,2), keepdims=True)
Out[23]: <tf.Tensor 'Max_10:0' shape=(3, 1, 1) dtype=int32>