一、相对位置编码
相对位置编码是针对绝对位置编码的一种改进,旨在捕捉序列中元素之间的相对位置信息。相对位置编码在处理长距离依赖关系和泛化到不同长度的序列时表现更好。
1、工作原理
相对位置编码的核心思想是,位置关系是相对的而不是绝对的。相对位置编码将位置差异(相对位置)纳入注意力计算中。
假设有一个序列长度为 \( N \),位置 \( i \) 和位置 \( j \) 的相对位置编码可以表示为 \( r_{ij} \),并用于调整注意力得分 \( e_{ij} \)。
2、实现方式
(1)相对位置嵌入
使用一个相对位置嵌入矩阵 \( W_r \) 来表示位置差异。位置 \( i \) 和位置 \( j \) 的相对位置编码 \( r_{ij} \) 可以通过查询这个嵌入矩阵得到。
\[
r_{ij} = W_r[j - i + N - 1]
\]
其中 \( N \) 是序列的最大长度,确保索引 \( j - i + N - 1 \) 在有效范围内。
(2)相对位置编码与注意力计算
将相对位置编码 \( r_{ij} \) 加入注意力得分的计算中:
\[
e_{ij} = \frac{(Q_i K_j^T)}{\sqrt{d_k}} + r_{ij}
\]
其中 \( Q_i \) 和 \( K_j \) 分别是位置 \( i \) 和位置 \( j \) 的查询和键向量, \( d_k \) 是键向量的维度。
3、示例
假设我们有一个长度为 4 的序列 \( [a, b, c, d] \),我们希望计算位置 \( i \) 到位置 \( j \) 的相对位置编码:
\( r_{00} = W_r[4 - 1] \)
\( r_{01} = W_r[4] \)
\( r_{10} = W_r[2] \)
\( r_{12} = W_r[4] \)
位置嵌入矩阵 \( W_r \) 的维度为 \( 2N-1 \),即 \( W_r \) 的大小为 \( (2 \times 4 - 1) = 7 \)。
二、混合位置编码
混合位置编码结合了绝对位置编码和相对位置编码,利用两者的优点,既能捕捉全局位置信息,又能捕捉局部的相对位置信息。
1、工作原理
混合位置编码的基本思想是将绝对位置编码和相对位置编码结合起来,提供更加丰富的位置信息。这种方式可以通过简单的相加、拼接或其他融合方法实现。
2、实现方式
(1)相加方式
绝对位置编码 \( \text{PE}_{i} \) 和相对位置编码 \( r_{ij} \) 相加:
\[
e_{ij} = \frac{(Q_i K_j^T)}{\sqrt{d_k}} + \text{PE}_{i} + r_{ij}
\]
(2)拼接方式
将绝对位置编码 \( \text{PE}_{i} \) 和相对位置编码 \( r_{ij} \) 拼接成一个向量,然后通过线性变换结合:
\[
e_{ij} = \frac{(Q_i K_j^T)}{\sqrt{d_k}} + W_o [\text{PE}_{i}; r_{ij}]
\]
其中,\( W_o \) 是一个可训练的权重矩阵,\([;]\) 表示拼接操作。
3、示例
假设我们有一个长度为 4 的序列 \( [a, b, c, d] \):
绝对位置编码 \( \text{PE}_{i} \) 可以通过正弦和余弦函数计算。
相对位置编码 \( r_{ij} \) 通过相对位置嵌入矩阵 \( W_r \) 计算。
使用混合位置编码时,注意力得分的计算将结合这两种编码方式。例如:
对于位置 \( i = 1 \) 和位置 \( j = 2 \):
绝对位置编码 \( \text{PE}_{1} \)
相对位置编码 \( r_{12} \)
(1)相加方式:
\[
e_{12} = \frac{(Q_1 K_2^T)}{\sqrt{d_k}} + \text{PE}_{1} + r_{12}
\]
(2)拼接方式:
\[
e_{12} = \frac{(Q_1 K_2^T)}{\sqrt{d_k}} + W_o [\text{PE}_{1}; r_{12}]
\]
三、总结
相对位置编码:更适合捕捉元素之间的相对关系,有助于处理长距离依赖和泛化到不同长度的序列。
混合位置编码:结合绝对位置编码和相对位置编码的优点,提供丰富的位置信息,适用于需要同时捕捉全局和局部位置信息的任务。
这两种编码方式的结合和应用,可以显著提高Transformer模型在处理复杂序列任务中的性能。