tf.concat是连接两个矩阵的操作,tf.concat(values,dim,name='concat')
按照dim给定的维度进行拼接,即,相应的维度增加,例子如下:
矩阵维度简单情形(shape为[2,3])
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
拼接后结果:
tf.concat([t1, t2], 0) # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1) # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
对拼接的结果shape
tf.shape(tf.concat([t1, t2], 0)) #新的维度shape [4, 3]
tf.shape(tf.concat([t1, t2], 1)) #新的维度shape [2, 6]
这里解释了当axis=0和axis=1的情况,怎么理解这个axis呢?其实这和numpy中的np.concatenate()用法是一样的。
axis=0 代表在第0个维度拼接
axis=1 代表在第1个维度拼接
矩阵都是2*2*2维度的情形
t1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
t2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]]
tf.concat([t1, t2], axis=-1)
-1表示最后一个维度,最后一个维度增加
输出结果为
<tf.Tensor 'concat_2:0' shape=(2, 2, 4) dtype=int32>