【深度学习API】TensorFlow - tf.concat()

tf.concat()

作用:拼接张量,主要参数为values,axis;

values为需要合并的list

axis决定了这些list中的张量如何合并,axis=0,合并第一维,axis=1,合并第二维,axis=2合并第三维,axis=3合并第四维度,前提是,除了需要合并的那一位纬度,其他不合并的纬度属性必须相同;

axis=3时,四个张量最后一维不同,合并第四维,代码:

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        a1 = tf.random_normal((100,28,28,64))
        a2 = tf.random_normal((100,28,28,128))
        a3 = tf.random_normal((100,28,28,32))
        a4 = tf.random_normal((100,28,28,32))

        b4 = tf.concat(values=[a1,a2,a3,a4],axis=3)
        print(b4)
结果:
Tensor("concat:0", shape=(100, 28, 28, 256), dtype=float32)

分别合并四个纬度,代码:

        a1 = tf.random_normal((100,28,28,128))
        a2 = tf.random_normal((100,28,28,128))
        a3 = tf.random_normal((100,28,28,128))
        a4 = tf.random_normal((100,28,28,128))

        b1 = tf.concat(values=[a1,a2,a3,a4],axis=0)
        b2 = tf.concat(values=[a1,a2,a3,a4],axis=1)
        b3 = tf.concat(values=[a1,a2,a3,a4],axis=2)
        b4 = tf.concat(values=[a1,a2,a3,a4],axis=3)


        print(b1,b2,b3,b4)

结果: 

Tensor("concat:0", shape=(400, 28, 28, 128), dtype=float32) 
Tensor("concat_1:0", shape=(100, 112, 28, 128), dtype=float32) 
Tensor("concat_2:0", shape=(100, 28, 112, 128), dtype=float32) 
Tensor("concat_3:0", shape=(100, 28, 28, 512), dtype=float32)

对应axis的维度,被合并了.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值