tf.concat实例用法与图解

tf.concat的参数为:

tf.concat(values, axis, name='concat')

其中的axis参数并不直观。

下面的代码是官网给出的示例代码。

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1)  # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]

个人认为可以这样来理解记忆:

t1t2都是一个batch的向量,每个向量是一个样本,例如图像用CNN或者文本用RNN编码后的hidden state。设t1t2分别为 m 1 × n 1 m_1\times n_1 m1×n1 m 2 × n 2 m_2\times n_2 m2×n2的矩阵。那么, m m m代表的是batch_size, n n n代表的是每个样本的特征数量。本文的向量按照行向量来画图理解。

axis=0就是将两个batch的向量取出来放到同一个batch中,组成batch_size更大的batch。如下图所示(图中每种颜色的小圆圈代表一种特征)。注意,两个batch的向量的特征必须相等,即向量的维度大小和含义都相等,即图中每个向量的圆圈个数相同且对应位置的圆圈颜色相同;而batch_size不必相等(当然也可以相等),例如图中一个batch有3个向量,另一个batch有4个向量。

在这里插入图片描述

axis=1就是将两个batch相同位置的向量取出来进行两个向量的拼接,然后作为结果的一个向量。如下图所示。注意,两个batch的batch_size必须相等(这样才能组成对);而两个batch的向量的特征可以不同(当然也可以相同)。

在这里插入图片描述

需要使用axis=1的典型的情景,是我们用不同的编码器对同一批数据进行编码,得到两个batch_size相等而特征不相等(至少含义不相等)的batch,这时我们可以将其按照上图的方式成对地拼接起来得到一个对于这批数据更全面的编码结果,传给后续的网络。更具体的例子是,使用Multi-Head Self-Attention,每个Head都会得到同一个句子的不同编码矩阵,而我们可以将其拼接起来。(下面两张图的图片来源

在这里插入图片描述

在这里插入图片描述

两个以上的矩阵进行concat也同理。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值