为什么需要位置信息?
背景介绍
Transformer 是一种基于注意力机制的神经网络模型,广泛应用于自然语言处理任务,如机器翻译、文本生成等。与传统的循环神经网络(RNN)不同,Transformer 没有内置的序列顺序处理能力,因此需要一种方法来引入序列中元素的位置信息。
-
自注意力机制的特点:Transformer 的核心是自注意力机制(Self-Attention),它能够在序列中任意两个位置之间建立直接的依赖关系。但是,由于这种机制对序列中元素的位置不敏感,如果不引入位置信息,模型就无法区分不同位置的元素,导致序列信息的丢失。
-
位置信息的重要性:在自然语言处理中,词语的顺序对句子的含义有着重要影响。例如,“我爱你”和“你爱我”虽然包含相同的词,但顺序不同,含义也不同。因此,引入位置信息对于模型理解序列数据至关重要。
为什么使用正弦和余弦函数表示位置信息
Transformer 模型的作者 Vaswani 等人在论文 “Attention Is All You Need” 中提出了一种 位置编码(Positional Encoding) 方法,使用正弦和余弦函数来表示位置信息,其原因和优势如下:
1. 捕捉不同频率的位置信息
-
多频率表示:通过对不同维度使用不同频率的正弦和余弦函数,位置编码能够在不同的尺度上捕捉位置信息。这使得模型可以学习到序列中不同范围的位置信息。
-
公式表示:
对于序列中位置为 ( pos ) 的元素,第 ( i ) 个维度的位置编码为:
P E ( p o s , 2 i ) = sin ( p o s 1000 0 2 i / d m o d e l ) PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)
P E ( p o s , 2 i + 1 ) = cos ( p o s 1000 0 2 i / d m o d e l ) PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i+1)=cos(100002i/dmodelpos)
其中, d m o d e l d_{model} dmodel是模型的维度, i i i 是维度的索引。
2. 方便模型学习相对位置关系
- 线性可加性:正弦和余弦函数具有良好的数学性质,例如:
sin
(
a
+
b
)
=
sin
a
cos
b
+
cos
a
sin
b
\sin(a + b) = \sin a \cos b + \cos a \sin b
sin(a+b)=sinacosb+cosasinb
cos
(
a
+
b
)
=
cos
a
cos
b
−
sin
a
sin
b
\cos(a + b) = \cos a \cos b - \sin a \sin b
cos(a+b)=cosacosb−sinasinb
这意味着,位置编码中的位置差异可以通过向量之间的线性运算体现,模型更容易学习到序列中元素的相对位置信息。
3. 不依赖序列长度
- 可扩展性:由于位置编码是通过函数计算的,理论上可以适用于任意长度的序列,无需预先设定最大序列长度,也不需要在训练时学习额外的参数。
4. 平滑的位置信息
- 连续性:正弦和余弦函数生成的位置信息在位置上是连续变化的,能够更好地表示序列中相邻位置的关联性。
为什么不直接使用位置索引 pos 作为位置编码
-
数值范围和尺度差异
- 嵌入向量的数值范围:词嵌入向量通常是经过训练的实数值,其数值一般在较小的范围内,如 -1 到 1 或 0 到 1。
- 位置索引的数值范围:位置索引 pos 是整数,可能从 0 增加到数十、数百,甚至更大的值,具体取决于序列的长度。
- 尺度不匹配:直接将大的整数位置索引与词嵌入向量相加,可能导致位置信息在数值上压倒词嵌入信息,或者在反向传播过程中导致梯度不稳定。
-
缺乏平滑的位置信息
- 离散跳变:位置索引 pos 是离散的整数,随着位置增加,值会出现跳跃。
- 难以捕捉相对关系:模型难以学习到相邻位置之间的平滑变化和相对位置信息。
-
泛化能力受限
- 固定序列长度:如果训练时序列长度较短,模型可能无法泛化到更长的序列,因为未见过的较大位置索引 pos 会导致模型行为不可预测。
- 无法捕捉周期性和层次性模式:直接使用位置索引无法表示序列中可能存在的周期性或层次性结构。
-
数学性质不足
- 缺乏有利于模型学习的数学性质:位置索引 pos 不具备像正弦和余弦函数那样的周期性和线性性质,模型难以通过简单的线性操作来学习位置之间的关系。
正弦和余弦如何得到相对位置
正弦和余弦函数的加法公式如下:
- 正弦加法公式:
sin ( a + b ) = sin a cos b + cos a sin b \sin(a + b) = \sin a \cos b + \cos a \sin b sin(a+b)=sinacosb+cosasinb
- 余弦加法公式:
cos ( a + b ) = cos a cos b − sin a sin b \cos(a + b) = \cos a \cos b - \sin a \sin b cos(a+b)=cosacosb−sinasinb
在 Transformer 中,位置编码 P E PE PE 的计算方式为:
{ P E ( p o s , 2 i ) = sin ( p o s / 1000 0 2 i / d m o d e l ) P E ( p o s , 2 i + 1 ) = cos ( p o s / 1000 0 2 i / d m o d e l ) \begin{cases} PE_{(pos, 2i)} = \sin\left(pos / 10000^{2i/d_{model}}\right) \\ PE_{(pos, 2i+1)} = \cos\left(pos / 10000^{2i/d_{model}}\right) \end{cases} {PE(pos,2i)=sin(pos/100002i/dmodel)PE(pos,2i+1)=cos(pos/100002i/dmodel)
其中:
- p o s pos pos 表示序列中元素的位置索引。
- i i i 是维度索引。
- d m o d e l d_{model} dmodel 是模型的维度。
假设模型需要计算两个位置 p o s pos pos 和 $ pos’$ 之间的相对位置信息 p o s ′ − p o s pos' - pos pos′−pos 。由于位置编码使用了正弦和余弦函数,加上其加法公式的性质,模型可以通过对位置编码的线性操作,推导出相对位置。
1. 相对位置的编码表示
考虑位置编码向量 P E ( p o s ) PE(pos) PE(pos) 和 P E ( p o s ′ ) PE(pos') PE(pos′),它们的差异可以反映出位置 p o s pos pos 和 p o s ′ pos' pos′ 之间的相对位置信息。
2. 利用加法公式推导
假设我们关注某一维度,设频率 ω = 1 / 1000 0 2 i / d m o d e l \omega = 1/10000^{2i/d_{model}} ω=1/100002i/dmodel,则有:
- 对于位置 p o s pos pos:
{ sin ( ω p o s ) cos ( ω p o s ) \begin{cases} \sin(\omega pos) \\ \cos(\omega pos) \end{cases} {sin(ωpos)cos(ωpos)
- 对于位置 p o s ′ pos' pos′:
{ sin ( ω p o s ′ ) cos ( ω p o s ′ ) \begin{cases} \sin(\omega pos') \\ \cos(\omega pos') \end{cases} {sin(ωpos′)cos(ωpos′)
我们可以表示 p o s ′ pos' pos′ 的位置编码为:
{ sin ( ω p o s ′ ) = sin ( ω ( p o s + Δ p o s ) ) cos ( ω p o s ′ ) = cos ( ω ( p o s + Δ p o s ) ) \begin{cases} \sin(\omega pos') = \sin(\omega (pos + \Delta pos)) \\ \cos(\omega pos') = \cos(\omega (pos + \Delta pos)) \end{cases} {sin(ωpos′)=sin(ω(pos+Δpos))cos(ωpos′)=cos(ω(pos+Δpos))
其中 Δ p o s = p o s ′ − p o s \Delta pos = pos' - pos Δpos=pos′−pos 是相对位移。
3. 应用加法公式
应用正弦和余弦的加法公式:
- 正弦部分:
sin ( ω p o s ′ ) = sin ( ω p o s + ω Δ p o s ) = sin ( ω p o s ) cos ( ω Δ p o s ) + cos ( ω p o s ) sin ( ω Δ p o s ) \sin(\omega pos') = \sin(\omega pos + \omega \Delta pos) = \sin(\omega pos) \cos(\omega \Delta pos) + \cos(\omega pos) \sin(\omega \Delta pos) sin(ωpos′)=sin(ωpos+ωΔpos)=sin(ωpos)cos(ωΔpos)+cos(ωpos)sin(ωΔpos)
- 余弦部分:
cos ( ω p o s ′ ) = cos ( ω p o s + ω Δ p o s ) = cos ( ω p o s ) cos ( ω Δ p o s ) − sin ( ω p o s ) sin ( ω Δ p o s ) \cos(\omega pos') = \cos(\omega pos + \omega \Delta pos) = \cos(\omega pos) \cos(\omega \Delta pos) - \sin(\omega pos) \sin(\omega \Delta pos) cos(ωpos′)=cos(ωpos+ωΔpos)=cos(ωpos)cos(ωΔpos)−sin(ωpos)sin(ωΔpos)
4. 线性组合形式
将上面的结果重新排列:
- 将 sin ( ω p o s ′ ) \sin(\omega pos') sin(ωpos′) 表示为 sin ( ω p o s ) \sin(\omega pos) sin(ωpos) 和 cos ( ω p o s ) \cos(\omega pos) cos(ωpos) 的线性组合:
sin ( ω p o s ′ ) = [ cos ( ω Δ p o s ) ] sin ( ω p o s ) + [ sin ( ω Δ p o s ) ] cos ( ω p o s ) \sin(\omega pos') = [\cos(\omega \Delta pos)] \sin(\omega pos) + [\sin(\omega \Delta pos)] \cos(\omega pos) sin(ωpos′)=[cos(ωΔpos)]sin(ωpos)+[sin(ωΔpos)]cos(ωpos)
- 将 cos ( ω p o s ′ ) \cos(\omega pos') cos(ωpos′) 表示为 sin ( ω p o s ) \sin(\omega pos) sin(ωpos) 和 cos ( ω p o s ) \cos(\omega pos) cos(ωpos) 的线性组合:
cos ( ω p o s ′ ) = [ cos ( ω Δ p o s ) ] cos ( ω p o s ) − [ sin ( ω Δ p o s ) ] sin ( ω p o s ) \cos(\omega pos') = [\cos(\omega \Delta pos)] \cos(\omega pos) - [\sin(\omega \Delta pos)] \sin(\omega pos) cos(ωpos′)=[cos(ωΔpos)]cos(ωpos)−[sin(ωΔpos)]sin(ωpos)
这意味着,位置 p o s ′ pos' pos′ 的编码可以通过位置 p o s pos pos 的编码通过 线性组合 来表示,其中系数是 cos ( ω Δ p o s ) \cos(\omega \Delta pos) cos(ωΔpos) 和 sin ( ω Δ p o s ) \sin(\omega \Delta pos) sin(ωΔpos),只与相对位移 Δ p o s \Delta pos Δpos 有关。