tf.concat
的参数为:
tf.concat(values, axis, name='concat')
其中的axis参数并不直观。
下面的代码是官网给出的示例代码。
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1) # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
个人认为可以这样来理解记忆:
t1
和t2
都是一个batch的向量,每个向量是一个样本,例如图像用CNN或者文本用RNN编码后的hidden state。设t1
和t2
分别为
m
1
×
n
1
m_1\times n_1
m1×n1和
m
2
×
n
2
m_2\times n_2
m2×n2的矩阵。那么,
m
m
m代表的是batch_size,
n
n
n代表的是每个样本的特征数量。本文的向量按照行向量来画图理解。
axis=0
就是将两个batch的向量取出来放到同一个batch中,组成batch_size更大的batch。如下图所示(图中每种颜色的小圆圈代表一种特征)。注意,两个batch的向量的特征必须相等,即向量的维度大小和含义都相等,即图中每个向量的圆圈个数相同且对应位置的圆圈颜色相同;而batch_size不必相等(当然也可以相等),例如图中一个batch有3个向量,另一个batch有4个向量。
axis=1
就是将两个batch相同位置的向量取出来进行两个向量的拼接,然后作为结果的一个向量。如下图所示。注意,两个batch的batch_size必须相等(这样才能组成对);而两个batch的向量的特征可以不同(当然也可以相同)。
需要使用axis=1
的典型的情景,是我们用不同的编码器对同一批数据进行编码,得到两个batch_size相等而特征不相等(至少含义不相等)的batch,这时我们可以将其按照上图的方式成对地拼接起来得到一个对于这批数据更全面的编码结果,传给后续的网络。更具体的例子是,使用Multi-Head Self-Attention,每个Head都会得到同一个句子的不同编码矩阵,而我们可以将其拼接起来。(下面两张图的图片来源)
两个以上的矩阵进行concat也同理。