tf.concat()
作用:拼接张量,主要参数为values,axis;
values为需要合并的list
axis决定了这些list中的张量如何合并,axis=0,合并第一维,axis=1,合并第二维,axis=2合并第三维,axis=3合并第四维度,前提是,除了需要合并的那一位纬度,其他不合并的纬度属性必须相同;
axis=3时,四个张量最后一维不同,合并第四维,代码:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
a1 = tf.random_normal((100,28,28,64))
a2 = tf.random_normal((100,28,28,128))
a3 = tf.random_normal((100,28,28,32))
a4 = tf.random_normal((100,28,28,32))
b4 = tf.concat(values=[a1,a2,a3,a4],axis=3)
print(b4)
结果:
Tensor("concat:0", shape=(100, 28, 28, 256), dtype=float32)
分别合并四个纬度,代码:
a1 = tf.random_normal((100,28,28,128))
a2 = tf.random_normal((100,28,28,128))
a3 = tf.random_normal((100,28,28,128))
a4 = tf.random_normal((100,28,28,128))
b1 = tf.concat(values=[a1,a2,a3,a4],axis=0)
b2 = tf.concat(values=[a1,a2,a3,a4],axis=1)
b3 = tf.concat(values=[a1,a2,a3,a4],axis=2)
b4 = tf.concat(values=[a1,a2,a3,a4],axis=3)
print(b1,b2,b3,b4)
结果:
Tensor("concat:0", shape=(400, 28, 28, 128), dtype=float32)
Tensor("concat_1:0", shape=(100, 112, 28, 128), dtype=float32)
Tensor("concat_2:0", shape=(100, 28, 112, 128), dtype=float32)
Tensor("concat_3:0", shape=(100, 28, 28, 512), dtype=float32)
对应axis的维度,被合并了.