tf.concat()函数介绍和示例
tf.concat([tensor1, tensor2, tensor3, … ], axis)
释义:将张量 tensor1, tensor2, tensor3, … 拼接
- axis,指定轴(维度)
示例:
import tensorflow as tf
t1 = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
t2 = tf.constant([[7, 8, 9], [10, 11, 12]], dtype=tf.float32)
T0 = tf.concat([t1, t2], axis=0) # 维度 0 处拼接,即列拼接
T1 = tf.concat([t1, t2], axis=1) # 维度 1 处拼接,即行拼接
with tf.Session() as sess:
print('维度 0 处拼接,即列拼接:\n', sess.run(T0))
print('='*30)
print('维度 1 处拼接,即行拼接:\n', sess.run(T1))
维度 0 处拼接,即列拼接:
[[ 1. 2. 3.]
[ 4. 5. 6.]
[ 7. 8. 9.]
[10. 11. 12.]]
==============================
维度 1 处拼接,即行拼接:
[[ 1. 2. 3. 7. 8. 9.]
[ 4. 5. 6. 10. 11. 12.]]