一种可理解的线性transformer

找到一种更简洁的形式,如下:

message passing公式。其含义是,每个token产生一个消息 Q i Q_i Qi。然后消息通过权重 m i m_i mi加权合并,最后通过权重 s i s_i si消息分发。极限情况下(即softmax中有一个token的值为1,其余token值为0)每一次传递,会将某个token的消息传给另一个token。问题在于极限情况下,每一层只能设定固定的传递次数。在全连接图的情况下,传递次数应为 O ( n 2 ) O(n^2) O(n2)。在softmax情况下,这种情况可能有所改善。权重平均的情况下,相当于每个token加入了一个相同常数。令 m i = w 1 x i , s i = w 2 x i m_i = w_1x_i, s_i = w_2x_i mi=w1xi,si=w2xi,有
Q = ∑ i e x p ( m i ) ∑ j e x p ( m j ) Q i Q = \sum_i\dfrac{exp(m_i)}{\sum_jexp(m_j)}Q_i Q=ijexp(mj)exp(mi)Qi

x n ′ = x n + e x p ( s i ) ∑ j e x p ( s j ) Q x_n' = x_n + \dfrac{exp(s_i)}{\sum_jexp(s_j)}Q xn=xn+jexp(sj)exp(si)Q
如果进行 k k k次合并,和 k k k次分发,总体公式如下:

x ′ = x + s o f t m a x ( S x , d i m = l ) s o f t m a x ( M x , d i m = l ) T Q x x' = x + softmax(Sx,dim = l) softmax(Mx, dim=l)^T Qx x=x+softmax(Sx,dim=l)softmax(Mx,dim=l)TQx

,其中 Q x : [ b a t c h , l e n , c ] , s o f t m a x ( M x ) T : [ b , k , l ] , e x p ( S x ) : [ b , l , k ] ,其中Qx:[batch, len, c], softmax(Mx)^T:[b, k, l], exp(Sx):[b,l,k] ,其中Qx:[batch,len,c],softmax(Mx)T:[b,k,l],exp(Sx):[b,l,k]

从公式上来看,更像一个低秩方法。需要看一下与其它的低秩方法的区别。

与之对比的self_attention公式如下,可以发现其实就是拆解了self attention的softmax,比较巧合。

x ′ = x + s o f t m a x ( ( Q x ) T ( K x ) , d i m = l ) ∗ ( V x ) x' = x + softmax((Qx)^T(Kx),dim=l) * (Vx) x=x+softmax((Qx)T(Kx),dim=l)(Vx)

其中

Q x : [ b , l , k ] , K x : [ b , k , l ] , V x : [ b , l , c ] Qx:[b,l,k], Kx:[b,k,l], Vx:[b,l,c] Qx:[b,l,k],Kx:[b,k,l],Vx:[b,l,c]

参考了一下linear transformer的论文,主要是https://blog.csdn.net/hymn1993/article/details/125254897。发现确实是低秩方法的一种,只不过用softmax作为核进行映射。不过比较巧妙的是,这套方法有明确的可解释意义。

transformer O ( n 2 ) O(n^2) O(n2)复杂度的关键点在于,对每个token都查询了一次。因此降低复杂度的一个行之有效的方法是降低查询的次数。因此提出竞争查询的方法。公式如下:

Q = ∑ i e x p ( z i ) ∑ j e x p ( z j ) Q i Q = \sum_i\dfrac{exp(z_i)}{\sum_jexp(z_j)}Q_i Q=ijexp(zj)exp(zi)Qi

x n ′ = x n + e x p ( z n ) ∑ j e x p ( z j ) ∑ m ( e x p ( Q ∗ K m ) ∑ k e x p ( Q ∗ K k ) V m ) x_n' = x_n + \dfrac{exp(z_n)}{\sum_jexp(z_j)}\sum_m (\dfrac{exp( Q* K_m)}{\sum_kexp(Q*K_k)}V_m) xn=xn+jexp(zj)exp(zn)m(kexp(QKk)exp(QKm)Vm)

Z Z Z向量为竞争向量,通过softmax归一化得到分布在tokens上的权重,根据 Z Z Z的权重对所有的query向量 Q i Q_i Qi进行求和,得到竞争成功的 Q Q Q向量。可以理解为这一步将所有要查询的东西编码到同一个向量中。然后正常按照transformer的办法用Q向量与每个token的key向量 K m K_m Km V m V_m Vm得到更新向量。然后按照 Q Q Q向量的比例,依次按照比例把更新向量赋值给所有的token。可以看出当 Q i Q_i Qi的比例为极限情况(0,0,…,1,…,0,0)时,相当于只对比例为1的token做查询。

另外一种公式是:

z i = w x i z_i = w x_i zi=wxi

Q = ∑ i e x p ( z i ) ∑ j e x p ( z j ) Q i Q = \sum_i\dfrac{exp(z_i)}{\sum_jexp(z_j)}Q_i Q=ijexp(zj)exp(zi)Qi

V = ∑ i e x p ( z i ) ∑ j e x p ( z j ) V i V = \sum_i\dfrac{exp(z_i)}{\sum_jexp(z_j)}V_i V=ijexp(zj)exp(zi)Vi

x n ′ = x n + r e l u ( Q K n ) V x_n' = x_n + relu(QK_n)V xn=xn+relu(QKn)V

也是具备非常明确的意义。

第二种存在GPT形式(casual mask attention):

要依次根据mask,生成每个token的 Q n , V n Q^n,V^n Qn,Vn。每个token的更新如下:

Q n = ∑ i = 0 n e x p ( z i ) ∑ j = 0 n e x p ( z j ) Q i Q^n = \sum_{i = 0}^n\dfrac{exp(z_i)}{\sum_{j = 0}^nexp(z_j)}Q_i Qn=i=0nj=0nexp(zj)exp(zi)Qi

V n = ∑ i = 0 n e x p ( z i ) ∑ j = 0 n e x p ( z j ) V i V^n = \sum_{i = 0}^n\dfrac{exp(z_i)}{\sum_{j = 0}^nexp(z_j)}V_i Vn=i=0nj=0nexp(zj)exp(zi)Vi

x n ′ = x n + r e l u ( Q n K n ) V n x_n' = x_n + relu(Q^nK_n)V^n xn=xn+relu(QnKn)Vn

注意 Q n Q_n Qn的计算是存在递归公式的,因此其复杂度为seq len线性相关,缺点是无法并行。

第二个公式是有明确意义的。每一层都预定了一个w,用来判断query的重要程度。因此w是一个先验知识,用来决定应该先查询什么,然后再查询什么。但是真正的查询向量Q又是和序列有关的,不同的序列有不同的查询向量。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值