concat(),stack(),gather()用法

import tensorflow as tf
a = [[1,1,1],
     [2,2,2]]
b = [[3,3,3],
     [4,4,4]]
c = [[[5,5,5],
      [6,6,6]],
     [[7,7,7],
      [8,8,8]],
     [[9,9,9],
     [10,10,10]]
     ]

sess = tf.Session()
#tf.range():用于创建数字序列变量
print(sess.run(tf.range(8,13,2)))#起点是8,不超过13,增量为2,结果:[ 8 10 12]
print(sess.run(tf.range(8,13)))#步长默认1,结果[ 8  9 10 11 12]
print(sess.run(tf.range(8,3,-1.5)))#增量为-1.5,结果[8.  6.5 5.  3.5]
print(sess.run(tf.range(5)))#起点默认从0 开始,个数为5,结果[0 1 2 3 4]

#tf.shape():输出维度
print(sess.run(tf.shape(c)))#输出C的维度:[3 2 3]

#tf.concat():合并数组
x1 = tf.concat([a,b],axis=0)
x2 = tf.concat([a,b],axis=1)

print(sess.run(x1))#[[1 1 1]
                   # [2 2 2]
                   # [3 3 3]
                   # [4 4 4]]

print(sess.run(tf.shape(x1)))#[4 3]

print(sess.run(x2))#[[1 1 1 3 3 3]
                   # [2 2 2 4 4 4]]

print(sess.run(tf.shape(x2)))#[2 6]


#tf.stack
x1 = tf.stack([a,b],axis=0)
x2 = tf.stack([a,b],axis=1)

print(sess.run(x1))#[[[1 1 1]
                   # [2 2 2]]
                   # [[3 3 3]
                   # [4 4 4]]]

print(sess.run(tf.shape(x1)))#[2 2 3]

print(sess.run(x2))#[[[1 1 1]
                   # [3 3 3]]
                   # [[2 2 2]
                   # [4 4 4]]]

print(sess.run(tf.shape(x2)))#[2 2 3]

#tf.gather():把向量中某些索引值提取出来,得到新的向量
#tf.gather_nd():同上,适用于多维
index_a = tf.Variable([0])
sess.run(tf.global_variables_initializer())
print(sess.run(tf.gather_nd(x1,index_a)))#[[1 1 1]
                                         # [2 2 2]]
print(sess.run(tf.gather_nd(tf.gather_nd(x1,index_a),index_a)))#[1 1 1]


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值