transformer为什么使用sin和cos表示位置信息

为什么需要位置信息?

背景介绍

Transformer 是一种基于注意力机制的神经网络模型,广泛应用于自然语言处理任务,如机器翻译、文本生成等。与传统的循环神经网络(RNN)不同,Transformer 没有内置的序列顺序处理能力,因此需要一种方法来引入序列中元素的位置信息。

  • 自注意力机制的特点:Transformer 的核心是自注意力机制(Self-Attention),它能够在序列中任意两个位置之间建立直接的依赖关系。但是,由于这种机制对序列中元素的位置不敏感,如果不引入位置信息,模型就无法区分不同位置的元素,导致序列信息的丢失。

  • 位置信息的重要性:在自然语言处理中,词语的顺序对句子的含义有着重要影响。例如,“我爱你”和“你爱我”虽然包含相同的词,但顺序不同,含义也不同。因此,引入位置信息对于模型理解序列数据至关重要。


为什么使用正弦和余弦函数表示位置信息

Transformer 模型的作者 Vaswani 等人在论文 “Attention Is All You Need” 中提出了一种 位置编码(Positional Encoding) 方法,使用正弦和余弦函数来表示位置信息,其原因和优势如下:

1. 捕捉不同频率的位置信息

  • 多频率表示:通过对不同维度使用不同频率的正弦和余弦函数,位置编码能够在不同的尺度上捕捉位置信息。这使得模型可以学习到序列中不同范围的位置信息。

  • 公式表示

    对于序列中位置为 ( pos ) 的元素,第 ( i ) 个维度的位置编码为:

P E ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i / d m o d e l ) PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)

P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d m o d e l ) PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i+1)=cos(100002i/dmodelpos)

其中, d m o d e l d_{model} dmodel是模型的维度, i i i 是维度的索引。

2. 方便模型学习相对位置关系
  • 线性可加性:正弦和余弦函数具有良好的数学性质,例如:

sin ⁡ ( a + b ) = sin ⁡ a cos ⁡ b + cos ⁡ a sin ⁡ b \sin(a + b) = \sin a \cos b + \cos a \sin b sin(a+b)=sinacosb+cosasinb
cos ⁡ ( a + b ) = cos ⁡ a cos ⁡ b − sin ⁡ a sin ⁡ b \cos(a + b) = \cos a \cos b - \sin a \sin b cos(a+b)=cosacosbsinasinb

这意味着,位置编码中的位置差异可以通过向量之间的线性运算体现,模型更容易学习到序列中元素的相对位置信息。

3. 不依赖序列长度
  • 可扩展性:由于位置编码是通过函数计算的,理论上可以适用于任意长度的序列,无需预先设定最大序列长度,也不需要在训练时学习额外的参数。
4. 平滑的位置信息
  • 连续性:正弦和余弦函数生成的位置信息在位置上是连续变化的,能够更好地表示序列中相邻位置的关联性。

为什么不直接使用位置索引 pos 作为位置编码

  1. 数值范围和尺度差异

    • 嵌入向量的数值范围:词嵌入向量通常是经过训练的实数值,其数值一般在较小的范围内,如 -1 到 1 或 0 到 1。
    • 位置索引的数值范围:位置索引 pos 是整数,可能从 0 增加到数十、数百,甚至更大的值,具体取决于序列的长度。
    • 尺度不匹配:直接将大的整数位置索引与词嵌入向量相加,可能导致位置信息在数值上压倒词嵌入信息,或者在反向传播过程中导致梯度不稳定。
  2. 缺乏平滑的位置信息

    • 离散跳变:位置索引 pos 是离散的整数,随着位置增加,值会出现跳跃。
    • 难以捕捉相对关系:模型难以学习到相邻位置之间的平滑变化和相对位置信息。
  3. 泛化能力受限

    • 固定序列长度:如果训练时序列长度较短,模型可能无法泛化到更长的序列,因为未见过的较大位置索引 pos 会导致模型行为不可预测。
    • 无法捕捉周期性和层次性模式:直接使用位置索引无法表示序列中可能存在的周期性或层次性结构。
  4. 数学性质不足

    • 缺乏有利于模型学习的数学性质:位置索引 pos 不具备像正弦和余弦函数那样的周期性和线性性质,模型难以通过简单的线性操作来学习位置之间的关系。

正弦和余弦如何得到相对位置

正弦和余弦函数的加法公式如下:

  • 正弦加法公式

sin ⁡ ( a + b ) = sin ⁡ a cos ⁡ b + cos ⁡ a sin ⁡ b \sin(a + b) = \sin a \cos b + \cos a \sin b sin(a+b)=sinacosb+cosasinb

  • 余弦加法公式

cos ⁡ ( a + b ) = cos ⁡ a cos ⁡ b − sin ⁡ a sin ⁡ b \cos(a + b) = \cos a \cos b - \sin a \sin b cos(a+b)=cosacosbsinasinb

在 Transformer 中,位置编码 P E PE PE 的计算方式为:

{ P E ( p o s , 2 i ) = sin ⁡ ( p o s / 1000 0 2 i / d m o d e l ) P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s / 1000 0 2 i / d m o d e l ) \begin{cases} PE_{(pos, 2i)} = \sin\left(pos / 10000^{2i/d_{model}}\right) \\ PE_{(pos, 2i+1)} = \cos\left(pos / 10000^{2i/d_{model}}\right) \end{cases} {PE(pos,2i)=sin(pos/100002i/dmodel)PE(pos,2i+1)=cos(pos/100002i/dmodel)

其中:

  • p o s pos pos 表示序列中元素的位置索引。
  • i i i 是维度索引。
  • d m o d e l d_{model} dmodel 是模型的维度。

假设模型需要计算两个位置 p o s pos pos 和 $ pos’$ 之间的相对位置信息 p o s ′ − p o s pos' - pos pospos 。由于位置编码使用了正弦和余弦函数,加上其加法公式的性质,模型可以通过对位置编码的线性操作,推导出相对位置。

1. 相对位置的编码表示

考虑位置编码向量 P E ( p o s ) PE(pos) PE(pos) P E ( p o s ′ ) PE(pos') PE(pos),它们的差异可以反映出位置 p o s pos pos p o s ′ pos' pos 之间的相对位置信息。

2. 利用加法公式推导

假设我们关注某一维度,设频率 ω = 1 / 1000 0 2 i / d m o d e l \omega = 1/10000^{2i/d_{model}} ω=1/100002i/dmodel,则有:

  • 对于位置 p o s pos pos

{ sin ⁡ ( ω p o s ) cos ⁡ ( ω p o s ) \begin{cases} \sin(\omega pos) \\ \cos(\omega pos) \end{cases} {sin(ωpos)cos(ωpos)

  • 对于位置 p o s ′ pos' pos

{ sin ⁡ ( ω p o s ′ ) cos ⁡ ( ω p o s ′ ) \begin{cases} \sin(\omega pos') \\ \cos(\omega pos') \end{cases} {sin(ωpos)cos(ωpos)

我们可以表示 p o s ′ pos' pos 的位置编码为:

{ sin ⁡ ( ω p o s ′ ) = sin ⁡ ( ω ( p o s + Δ p o s ) ) cos ⁡ ( ω p o s ′ ) = cos ⁡ ( ω ( p o s + Δ p o s ) ) \begin{cases} \sin(\omega pos') = \sin(\omega (pos + \Delta pos)) \\ \cos(\omega pos') = \cos(\omega (pos + \Delta pos)) \end{cases} {sin(ωpos)=sin(ω(pos+Δpos))cos(ωpos)=cos(ω(pos+Δpos))

其中 Δ p o s = p o s ′ − p o s \Delta pos = pos' - pos Δpos=pospos 是相对位移。

3. 应用加法公式

应用正弦和余弦的加法公式:

  • 正弦部分

sin ⁡ ( ω p o s ′ ) = sin ⁡ ( ω p o s + ω Δ p o s ) = sin ⁡ ( ω p o s ) cos ⁡ ( ω Δ p o s ) + cos ⁡ ( ω p o s ) sin ⁡ ( ω Δ p o s ) \sin(\omega pos') = \sin(\omega pos + \omega \Delta pos) = \sin(\omega pos) \cos(\omega \Delta pos) + \cos(\omega pos) \sin(\omega \Delta pos) sin(ωpos)=sin(ωpos+ωΔpos)=sin(ωpos)cos(ωΔpos)+cos(ωpos)sin(ωΔpos)

  • 余弦部分

cos ⁡ ( ω p o s ′ ) = cos ⁡ ( ω p o s + ω Δ p o s ) = cos ⁡ ( ω p o s ) cos ⁡ ( ω Δ p o s ) − sin ⁡ ( ω p o s ) sin ⁡ ( ω Δ p o s ) \cos(\omega pos') = \cos(\omega pos + \omega \Delta pos) = \cos(\omega pos) \cos(\omega \Delta pos) - \sin(\omega pos) \sin(\omega \Delta pos) cos(ωpos)=cos(ωpos+ωΔpos)=cos(ωpos)cos(ωΔpos)sin(ωpos)sin(ωΔpos)

4. 线性组合形式

将上面的结果重新排列:

  • sin ⁡ ( ω p o s ′ ) \sin(\omega pos') sin(ωpos) 表示为 sin ⁡ ( ω p o s ) \sin(\omega pos) sin(ωpos) cos ⁡ ( ω p o s ) \cos(\omega pos) cos(ωpos) 的线性组合

sin ⁡ ( ω p o s ′ ) = [ cos ⁡ ( ω Δ p o s ) ] sin ⁡ ( ω p o s ) + [ sin ⁡ ( ω Δ p o s ) ] cos ⁡ ( ω p o s ) \sin(\omega pos') = [\cos(\omega \Delta pos)] \sin(\omega pos) + [\sin(\omega \Delta pos)] \cos(\omega pos) sin(ωpos)=[cos(ωΔpos)]sin(ωpos)+[sin(ωΔpos)]cos(ωpos)

  • cos ⁡ ( ω p o s ′ ) \cos(\omega pos') cos(ωpos) 表示为 sin ⁡ ( ω p o s ) \sin(\omega pos) sin(ωpos) cos ⁡ ( ω p o s ) \cos(\omega pos) cos(ωpos) 的线性组合

cos ⁡ ( ω p o s ′ ) = [ cos ⁡ ( ω Δ p o s ) ] cos ⁡ ( ω p o s ) − [ sin ⁡ ( ω Δ p o s ) ] sin ⁡ ( ω p o s ) \cos(\omega pos') = [\cos(\omega \Delta pos)] \cos(\omega pos) - [\sin(\omega \Delta pos)] \sin(\omega pos) cos(ωpos)=[cos(ωΔpos)]cos(ωpos)[sin(ωΔpos)]sin(ωpos)

这意味着,位置 p o s ′ pos' pos 的编码可以通过位置 p o s pos pos 的编码通过 线性组合 来表示,其中系数是 cos ⁡ ( ω Δ p o s ) \cos(\omega \Delta pos) cos(ωΔpos) sin ⁡ ( ω Δ p o s ) \sin(\omega \Delta pos) sin(ωΔpos),只与相对位移 Δ p o s \Delta pos Δpos 有关。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值