【Tensorflow】有关于tensor的shape

最近改结构,但是因为不熟练的缘故,所以不能很快的找到axis,用一下代码检测所选的方向是否正确。

import numpy as np
import tensorflow as tf
import warnings
warnings.filterwarnings('ignore')

a = np.reshape(range(48),(4,6,2))
print("ori",a)
sess = tf.InteractiveSession()
res = tf.split(a, axis=1, num_or_size_splits=[4,2])
# print(res)
b1 = res[0] # (4, 4, 2)
b2 = res[1] # (4, 2, 2)
print("part1",res[0].eval())
print("part2",res[1].eval())
# 求b2第二个维度平均
c = tf.reduce_mean(b2, axis=1, keepdims=True)
print(c) # shape=(4, 1, 2)
print(c.eval())

d = tf.concat([b1,c],axis=1)
print(d.eval())
print(d.shape)

"""
ori [[[ 0  1]
  [ 2  3]
  [ 4  5]
  [ 6  7]
  [ 8  9]
  [10 11]]

 [[12 13]
  [14 15]
  [16 17]
  [18 19]
  [20 21]
  [22 23]]

 [[24 25]
  [26 27]
  [28 29]
  [30 31]
  [32 33]
  [34 35]]

 [[36 37]
  [38 39]
  [40 41]
  [42 43]
  [44 45]
  [46 47]]]

part1 [[[ 0  1]
  [ 2  3]
  [ 4  5]
  [ 6  7]]

 [[12 13]
  [14 15]
  [16 17]
  [18 19]]

 [[24 25]
  [26 27]
  [28 29]
  [30 31]]

 [[36 37]
  [38 39]
  [40 41]
  [42 43]]]
part2 [[[ 8  9]
  [10 11]]

 [[20 21]
  [22 23]]

 [[32 33]
  [34 35]]

 [[44 45]
  [46 47]]]
  
Tensor("Mean:0", shape=(4, 1, 2), dtype=int64)
[[[ 9 10]]

 [[21 22]]

 [[33 34]]

 [[45 46]]]
[[[ 0  1]
  [ 2  3]
  [ 4  5]
  [ 6  7]
  [ 9 10]]

 [[12 13]
  [14 15]
  [16 17]
  [18 19]
  [21 22]]

 [[24 25]
  [26 27]
  [28 29]
  [30 31]
  [33 34]]

 [[36 37]
  [38 39]
  [40 41]
  [42 43]
  [45 46]]]
(4, 5, 2)
"""

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值