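Principles of the Rotary Position Embedding Algorithm
Background of RoPE
"Attention Is All You Need" uses absolute position embeddings and shows that embedding position vectors is effective, which is how the idea of position vectors gradually came into view. There are currently two main kinds of position encoding: absolute and relative. Absolute position encoding is intuitive, since it embeds each token's position directly, but its drawbacks are also fairly clear: the length is fixed, and the relationship between the relative positions of tokens is hard to take into account during computation. Relative position encoding, in turn, is compatible with variable-length sequences and accounts for the relative positions of different tokens, which addresses the problems of absolute encoding well, but it requires more compute. Rotary Position Embedding (RoPE) combines the advantages of both: it proposes a position encoding that accounts for relative position while still carrying absolute position.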
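How RoPE Works
Finding the relative-position relationship
In the attention computation, $QK^{T}$ involves inner products between many pairs of tokens. The question is how the inner product between elements $x_{m}$ and $x_{n}$ can be made to reflect their relative position $m-n$. The authors formalize this problem as the following equation: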
$$
\langle f_{q}(x_{m},m),\ f_{k}(x_{n},n)\rangle = g(x_{m},x_{n},m-n)
$$
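As a quick numerical illustration of this requirement, the sketch below assumes one particular choice of $f$: rotating a two-dimensional vector by an angle proportional to its position. The helper names `rotate` and `dot` and the frequency `THETA = 0.1` are hypothetical and chosen only for illustration; the derivation that follows is what motivates a rotation of this kind.

```python
import math

THETA = 0.1  # hypothetical rotation frequency, chosen only for this illustration


def rotate(x, pos, theta=THETA):
    """Rotate the 2D vector x counter-clockwise by the angle pos * theta."""
    angle = pos * theta
    c, s = math.cos(angle), math.sin(angle)
    return (c * x[0] - s * x[1], s * x[0] + c * x[1])


def dot(a, b):
    """Inner product of two 2D vectors."""
    return a[0] * b[0] + a[1] * b[1]


x_q = (1.0, 2.0)    # arbitrary query vector
x_k = (0.5, -1.5)   # arbitrary key vector

# Two position pairs that share the same offset m - n = 2.
score_a = dot(rotate(x_q, 3), rotate(x_k, 1))
score_b = dot(rotate(x_q, 10), rotate(x_k, 8))

# A pair with a different offset m - n = 5.
score_c = dot(rotate(x_q, 6), rotate(x_k, 1))

print("same offset, same score:", math.isclose(score_a, score_b, abs_tol=1e-9))
print("different offset, different score:",
      not math.isclose(score_a, score_c, abs_tol=1e-6))
```

Under this assumed $f$, the score depends on the positions only through $m-n$, which is exactly the property the equation asks for.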
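The authors first consider the case where $x$ is a two-dimensional vector, bringing in the geometry of two-dimensional vectors and complex arithmetic. Given two two-dimensional vectors $x_q$ and $x_k$, we have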
$$
\begin{align}
f_{q}(x_q,m) &= R_{q}(x_q,m)e^{i\Theta_q(x_q,m)} \\
f_{k}(x_k,n) &= R_{k}(x_k,n)e^{i\Theta_k(x_k,n)} \\
g(x_q,x_k,m-n) &= \mathrm{Re}(l(x_q,x_k,m-n)) = \mathrm{Re}(R_g(x_q,x_k,m-n)e^{i\Theta_g(x_q,x_k,m-n)})
\end{align}
$$
By the properties of two-dimensional vectors, the vector can be identified with a complex number:
$$
R_q(x_q,m)=\left(\begin{matrix} q_{m}^{1} \\ q_{m}^{2} \end{matrix}\right) = q_{m}^{1}+i q_{m}^{2}
$$
By Euler's formula,
$$
e^{i\Theta_q(x_q,m)} = \cos(\Theta_q(x_q,m))+i \sin(\Theta_q(x_q,m))
$$
so that
$$
\begin{align*}
f_q(x_q,m) &= (q_{m}^{1}+i q_{m}^{2})(\cos(\Theta_q(x_q,m))+i \sin(\Theta_q(x_q,m)))\\
&= q_m^1\cos(\Theta_q(x_q,m))-q_m^2\sin(\Theta_q(x_q,m)) + i\left(q_m^1 \sin(\Theta_q(x_q,m))+q_m^2\cos(\Theta_q(x_q,m))\right) \\
&= \left(\begin{matrix} q_m^1\cos(\Theta_q(x_q,m))-q_m^2\sin(\Theta_q(x_q,m)) \\ q_m^1 \sin(\Theta_q(x_q,m))+q_m^2\cos(\Theta_q(x_q,m)) \end{matrix}\right) \\
&= \left(\begin{matrix}\cos(\Theta_q(x_q,m)) & -\sin(\Theta_q(x_q,m)) \\ \sin(\Theta_q(x_q,m)) & \cos(\Theta_q(x_q,m)) \end{matrix}\right)\left(\begin{matrix}q_m^1\\ q_m^2\end{matrix}\right)
\end{align*}
$$
Similarly,
$$
f_k(x_k,n) = \left(\begin{matrix}\cos(\Theta_k(x_k,n)) & -\sin(\Theta_k(x_k,n)) \\ \sin(\Theta_k(x_k,n)) & \cos(\Theta_k(x_k,n)) \end{matrix}\right)\left(\begin{matrix}k_n^1\\ k_n^2\end{matrix}\right)
$$
The inner product then follows by transposing the first factor and applying the angle-difference identities:
$$
\begin{align*}
\langle f_q(x_q,m),\ f_k(x_k,n)\rangle
&= \left(\left(\begin{matrix}\cos(\Theta_q(x_q,m)) & -\sin(\Theta_q(x_q,m)) \\ \sin(\Theta_q(x_q,m)) & \cos(\Theta_q(x_q,m)) \end{matrix}\right)\left(\begin{matrix}q_m^1\\ q_m^2\end{matrix}\right)\right)^{T}\left(\begin{matrix}\cos(\Theta_k(x_k,n)) & -\sin(\Theta_k(x_k,n)) \\ \sin(\Theta_k(x_k,n)) & \cos(\Theta_k(x_k,n)) \end{matrix}\right)\left(\begin{matrix}k_n^1\\ k_n^2\end{matrix}\right) \\
&= \left(\begin{matrix} q_m^1 & q_m^2 \end{matrix}\right)\left(\begin{matrix}\cos(\Theta_q(x_q,m)) & \sin(\Theta_q(x_q,m)) \\ -\sin(\Theta_q(x_q,m)) & \cos(\Theta_q(x_q,m)) \end{matrix}\right)\left(\begin{matrix}\cos(\Theta_k(x_k,n)) & -\sin(\Theta_k(x_k,n)) \\ \sin(\Theta_k(x_k,n)) & \cos(\Theta_k(x_k,n)) \end{matrix}\right)\left(\begin{matrix}k_n^1\\ k_n^2\end{matrix}\right) \\
&= \left(\begin{matrix} q_m^1 & q_m^2 \end{matrix}\right)\left(\begin{matrix}\cos(\Theta_q(x_q,m))\cos(\Theta_k(x_k,n))+\sin(\Theta_q(x_q,m))\sin(\Theta_k(x_k,n)) & \sin(\Theta_q(x_q,m))\cos(\Theta_k(x_k,n))-\sin(\Theta_k(x_k,n))\cos(\Theta_q(x_q,m)) \\ \sin(\Theta_k(x_k,n))\cos(\Theta_q(x_q,m))-\sin(\Theta_q(x_q,m))\cos(\Theta_k(x_k,n)) & \cos(\Theta_q(x_q,m))\cos(\Theta_k(x_k,n))+\sin(\Theta_q(x_q,m))\sin(\Theta_k(x_k,n)) \end{matrix}\right)\left(\begin{matrix}k_n^1\\ k_n^2\end{matrix}\right) \\
&= \left(\begin{matrix} q_m^1 & q_m^2 \end{matrix}\right)\left(\begin{matrix}\cos(\Theta_k(x_k,n)-\Theta_q(x_q,m)) & -\sin(\Theta_k(x_k,n)-\Theta_q(x_q,m)) \\ \sin(\Theta_k(x_k,n)-\Theta_q(x_q,m)) & \cos(\Theta_k(x_k,n)-\Theta_q(x_q,m)) \end{matrix}\right)\left(\begin{matrix}k_n^1\\ k_n^2\end{matrix}\right) \\
&= \left(\begin{matrix} q_m^1\cos(\Theta_k(x_k,n)-\Theta_q(x_q,m))+q_m^2\sin(\Theta_k(x_k,n)-\Theta_q(x_q,m)) & -q_m^1\sin(\Theta_k(x_k,n)-\Theta_q(x_q,m))+q_m^2\cos(\Theta_k(x_k,n)-\Theta_q(x_q,m)) \end{matrix}\right)\left(\begin{matrix}k_n^1\\ k_n^2\end{matrix}\right) \\
&= (q_m^1k_n^1+q_m^2k_n^2)\cos(\Theta_k(x_k,n)-\Theta_q(x_q,m))+(q_m^2k_n^1-q_m^1k_n^2)\sin(\Theta_k(x_k,n)-\Theta_q(x_q,m))
\end{align*}
$$
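The derivation above can be checked numerically. The following is a minimal sketch using only the Python standard library; the function names (`rotate_complex`, `rotate_matrix`, `closed_form`) and the test values are hypothetical, and the angles `theta_q` and `theta_k` simply stand in for $\Theta_q(x_q,m)$ and $\Theta_k(x_k,n)$. It checks that multiplying by $e^{i\Theta}$ matches the rotation-matrix form, and that the inner product of the two rotated vectors matches the last line of the derivation.

```python
import cmath
import math


def rotate_complex(v, angle):
    """Rotate the 2D vector v by multiplying its complex form by e^{i*angle}."""
    z = complex(v[0], v[1]) * cmath.exp(1j * angle)
    return (z.real, z.imag)


def rotate_matrix(v, angle):
    """Rotate the 2D vector v with the explicit 2x2 rotation matrix."""
    c, s = math.cos(angle), math.sin(angle)
    return (c * v[0] - s * v[1], s * v[0] + c * v[1])


def closed_form(q, k, theta_q, theta_k):
    """Last line of the derivation, expressed through theta_k - theta_q."""
    delta = theta_k - theta_q
    return ((q[0] * k[0] + q[1] * k[1]) * math.cos(delta)
            + (q[1] * k[0] - q[0] * k[1]) * math.sin(delta))


q, k = (1.0, 2.0), (0.5, -1.5)   # arbitrary test vectors
theta_q, theta_k = 0.7, -0.3     # arbitrary test angles

fq = rotate_matrix(q, theta_q)
fk = rotate_matrix(k, theta_k)
inner = fq[0] * fk[0] + fq[1] * fk[1]

# Complex multiplication and the rotation matrix give the same vector.
fq_c = rotate_complex(q, theta_q)
forms_agree = all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(fq, fq_c))

# The direct inner product equals the closed-form expression.
identity_holds = math.isclose(inner, closed_form(q, k, theta_q, theta_k),
                              abs_tol=1e-12)

print(forms_agree, identity_holds)
```

Because the closed form contains the two angles only through their difference, any choice of $\Theta$ that is linear in position makes the score depend on $m-n$ alone.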