TensorFlow中张量的约减(Reduce)方向

张量的Reduce方向


     对于多维张量而言, 约减的方向是一个需要明确的问题。在TensorFlow中, 提供了很多关于约减的函数, 如tf.reduce_sum, tf.reduce_mean, tf.reduce_max, tf.reduce_min等函数, 它们的约减原理都是一样的,即从一大批数据中,不断减少数据量,直到找到满足要求的数据。
下面以tf.reduce_sum()来说明张量的约减方向。原型如下:

tf.reduce_sum(
    input_tensor,
    axis=None,
    keepdims=False,
    name=None,
    reduction_indices=None)

只有第一个参数input_tensor是必须的。对张量(多维数组)而言,约减是有方向性的。第2个参数axis,决定了约减的轴方向。axis=0,垂直方向上约减, axis=1,水平方向上约减。并且约减可以有先后顺序。因此axis的值可以是一个向量,比如axis=[1,0], 表示先水平方向约减,然后垂直方向上约减。axis默认为None,表示所有维度的张量都会依次约减。

参数keepdims为True, 那么每个维度的张量被约减到长度为1, 即保留了维度信息。

下面给出代码以及运行结果:

x = tf.constant([[1,1,1], [1,1,1]])
a = tf.reduce_sum(x)
b = tf.reduce_sum(x, 0)  # 垂直方向上约减
c = tf.reduce_sum(x, 1)  # 水平方向上约减
d = tf.reduce_sum(x, 1, keepdims=True)  # 每个维度的张量被约减到长度为1, 即保留了维度信息
e = tf.reduce_sum(x, [0, 1])  #先垂直后水平
with tf.Session() as sess:
    print('a =', sess.run(a))
    print('b =', sess.run(b))
    print('c =', sess.run(c))
    print('d =', sess.run(d))
    print('e =', sess.run(e))

    结果为:

a = 6
b = [2 2 2]
c = [3 3]
d = [[3]
 [3]]
e = 6

上述的解释虽然直观,但有很大的局限性。这种轴的概念,在维度小于2时,容易理解。且对于0表示垂直方向, 1表示水平方向是人为强加的。当在维度>=3时,就难以找到直观可理解的方向。

更加普适的解释应该按张量括号层次的方式来理解。张量括号由外到内,对应从小到大的维数。

    当指定reduce_sum函数的axis=0时,就是在第0个维度的元素之间进行sum操作,也就是除掉最外层括号后对应的两个元素,即[[1,1,1],[2,2,2]],[[3,3,3],[4,4,4]],然后对同一个括号层次下的这两个张量实施加法约减操作,即张量[[1,1,1],[2,2,2]]和
张量[[3,3,3],[4,4,4]]整体相加, 其结果为[[4,4,4],[6,6,6]]。没有被约减的维度,其括号层次保持不变。 

类似的,当axis=1时,就是在第1个维度的元素之间进行sum操作,也就是去掉中间层括号对应的元素[1,1,1],[2,2,2]和[3,3,3],[4,4,4]。需要注意的时, 原来在同一个括号层次内的张量两两相加,即[1,1,1]和[2,2,2]向量相加,[3,3,3]和[4,4,4]向量相加。
没有被约减的维度,其括号保持不变,结果得到 [[3,3,3],[7,7,7]]。
     当axis=2时,就是除掉最内层的括号,然后在最内层括号的元素之间进行sum操作。即1+1+1=3,2+2+2=6,3+3+3=9,4+4+4=12。实施约减之后,该层次括号消失,其他维度的括号保留。结果得到[[3,6],[9,12]]。
 这里为了便于区分,用逗号','将同一层次的不同元素隔开,实际上TensorFlow中,不同元素是用 空格隔开的。 事实上,每一个维度的约减,在实施之后,该维度都会消失。
 
 下面用一个简单的程序来验证上面的描述:

x1 = tf.constant([
    [[1,1,1],[2,2,2]],
    [[3,3,3],[4,4,4]]
    ])
z0 = tf.reduce_sum(x1, 0)
z1 = tf.reduce_sum(x1, 1)
z2 = tf.reduce_sum(x1, 2)
z3 = tf.reduce_sum(x1)
with tf.Session() as sess:
    print("============>\n", sess.run(z0))
    print("============>\n", sess.run(z1))
    print("============> \n", sess.run(z2))
    print("============>\n ", sess.run(z3))

结果如下:

============>
 [[4 4 4]
 [6 6 6]]
============>
 [[3 3 3]
 [7 7 7]]
============>
 [[ 3  6]
 [ 9 12]]
============>
  30

 

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值