tf.concat([tensor1, tensor2, tensor3,...], axis)
tf.concat()的功能是将输入参数中的tensor1, tensor2, tensor3,...进行拼接,看了一些关于拼接的解释,不是特别直观的理解,再加上案例事件,个人总结拼接的规则是,axis指的是shape向量的下标索引,来指示合并维度的:
例如4个tensorflow的shape为[n,h,w,c]的张量t1,t2,t3,t4,
当执行t5 = tf.concate([t1,t2,t3,t4],0)表示对shape向量的下标0的维度进行合并,也就是对shape的维度n进行合并,t5的shape为[4*n,h,w,c]
当执行t5 = tf.concate([t1,t2,t3,t4],1)表示对shape向量的下标1的维度进行合并,也就是对shape的维度h进行合并,t5的shape为[n,4*h,w,c]
当执行t5 = tf.concate([t1,t2,t3,t4],2)表示对shape向量的下标2的维度进行合并,也就是对shape的维度w进行合并,t5的shape为[n,h,4*w,c]
当执行t5 = tf.concate([t1,t2,t3,t4],3)表示对shape向量的下标3的维度进行合并,也就是对shape的维度w进行合并,t5的shape为[n,h,w,4*c]