1、拼接张量:tf.concat()
用法:tf.concat([tensor1, tensor2, tensor3,…], axis, name=‘concat’)
tf.concat()拼接的张量只会改变一个维度上元素大小,其他维度的元素大小是保存不变,用axis的设置不同维度进行拼接。
In [1]: import tensorflow as tf
In [2]: a = tf.ones([4, 35, 8])
In [3]: b = tf.ones([2, 35, 8])
In [4]: c = tf.concat([a, b], axis=0)
In [5]: c.shape
Out [5]: TensorShape([6, 35, 8])
In [6]: import tensorflow as tf
In [7]: a = tf.ones([4, 32, 8])
In [8]: b = tf.ones([4, 3, 8])
In [9]: c = tf.concat([a, b], axis=0).shape
Out [9]: TensorShape([6, 35, 8])
注意:张量进行拼接时仅拼接的那个维度axis的元素大小可不同,其余维度上的元素大小必须相同才可进行拼接。
In [1]: import tensorflow as tf
In [2]: a = tf.ones([4, 35, 8])
In [3]: b = tf.ones([3, 33, 8])
In [4]: tf.concat([a, b], axis=0)
InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [4,35,8] vs. shape[1] = [3,33,8] [Op