An interpretable linear transformer

We can find a more concise form, as follows:

This is a message-passing formulation: each token produces a message $Q_i$; the messages are merged with weights $m_i$, and the merged message is then distributed with weights $s_i$. In the limiting case (the softmax puts a value of 1 on one token and 0 on the rest), each pass hands one token's message to another token. The problem is that, in this limiting case, a layer can only perform a fixed number of passes, whereas for a fully connected graph the number of passes should be $O(n^2)$. With a softmax this may be somewhat better, and in the equally weighted case it amounts to adding the same constant to every token. Let $m_i = w_1 x_i$, $s_i = w_2 x_i$; then
$$Q = \sum_i \frac{\exp(m_i)}{\sum_j \exp(m_j)} Q_i$$

$$x_n' = x_n + \frac{\exp(s_n)}{\sum_j \exp(s_j)} Q$$
If we perform $k$ merges and $k$ distributions, the overall formula is:

$$x' = x + \mathrm{softmax}(Sx, \mathrm{dim}=l)\,\mathrm{softmax}(Mx, \mathrm{dim}=l)^T\, Qx$$

where $Qx: [\mathrm{batch}, \mathrm{len}, c]$, $\mathrm{softmax}(Mx)^T: [b, k, l]$, $\mathrm{softmax}(Sx): [b, l, k]$.
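
A minimal PyTorch sketch of this update, using the shapes above ($S$, $M$, $Q$ taken as plain projection matrices, which is my reading of the text):

import torch
import torch.nn.functional as F

def lowrank_update(x, S, M, Q):
    # x: [batch, len, c]; S, M: [c, k]; Q: [c, c]
    s = F.softmax(x @ S, dim=1)           # [b, l, k], softmax over the token dimension
    m = F.softmax(x @ M, dim=1)           # [b, l, k]
    msg = m.transpose(1, 2) @ (x @ Q)     # [b, k, c]: merge tokens into k pooled messages
    return x + s @ msg                    # [b, l, c]: distribute the k messages back

The token-mixing step here costs $O(l \cdot k \cdot c)$ instead of $O(l^2 \cdot c)$.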

From the formula, this looks more like a low-rank method; it is worth checking how it differs from other low-rank approaches.

For comparison, the self-attention formula is given below. One can see that the form above essentially decomposes the softmax of self-attention, which is a nice coincidence.

$$x' = x + \mathrm{softmax}\big((Qx)^T(Kx), \mathrm{dim}=l\big)\,(Vx)$$

where

$$Qx: [b, l, k], \quad Kx: [b, k, l], \quad Vx: [b, l, c]$$
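
For contrast, a sketch of the dense self-attention update with the same shape conventions (projection matrices again hypothetical, and the usual $1/\sqrt{d}$ scaling omitted to mirror the formula as written); the $[l, l]$ attention matrix is the $O(n^2)$ term. Imports as in the sketch above:

def self_attention_update(x, Wq, Wk, Wv):
    # x: [b, l, c]; Wq, Wk: [c, k]; Wv: [c, c]
    q, k, v = x @ Wq, x @ Wk, x @ Wv                # [b, l, k], [b, l, k], [b, l, c]
    att = F.softmax(q @ k.transpose(1, 2), dim=-1)  # [b, l, l] -- quadratic in sequence length
    return x + att @ v                              # [b, l, c]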

I looked into the linear transformer papers, mainly via https://blog.csdn.net/hymn1993/article/details/125254897, and this is indeed a kind of low-rank method, just with softmax used as the kernel map. What is neat, though, is that this formulation has a clear, interpretable meaning.

The key source of the transformer's $O(n^2)$ complexity is that a query is issued for every token. An effective way to reduce the complexity is therefore to reduce the number of queries, which motivates a competitive-query scheme. The formulas are as follows:

$$Q = \sum_i \frac{\exp(z_i)}{\sum_j \exp(z_j)} Q_i$$

$$x_n' = x_n + \frac{\exp(z_n)}{\sum_j \exp(z_j)} \sum_m \frac{\exp(Q \cdot K_m)}{\sum_k \exp(Q \cdot K_k)} V_m$$

The vector $Z$ is a competition vector: a softmax over it gives a distribution over the tokens, and the query vectors $Q_i$ are summed with these weights to obtain the winning query $Q$. This step can be read as encoding everything that needs to be queried into a single vector. $Q$ is then used exactly as in a standard transformer against each token's key $K_m$ and value $V_m$ to obtain an update vector, and that update is distributed back to all tokens in proportion to the same weights. In the limiting case where the weights over the $Q_i$ are $(0, 0, \ldots, 1, \ldots, 0, 0)$, this reduces to querying only the token whose weight is 1.
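
A sketch of this competitive-query update (single head; projection names are mine; imports as above):

def competitive_query_update(x, Wz, Wq, Wk, Wv):
    # x: [b, l, c]; Wz: [c, 1]; Wq, Wk, Wv: [c, c]
    z = F.softmax(x @ Wz, dim=1)                    # [b, l, 1], competition weights over tokens
    Q = (z * (x @ Wq)).sum(dim=1, keepdim=True)     # [b, 1, c], the single winning query
    K, V = x @ Wk, x @ Wv                           # [b, l, c]
    att = F.softmax(Q @ K.transpose(1, 2), dim=-1)  # [b, 1, l], one query against all keys
    upd = att @ V                                   # [b, 1, c], shared update vector
    return x + z * upd                              # distribute it back with the same weights

Since only one query is issued, the attention part costs $O(n)$ in the sequence length rather than $O(n^2)$.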

Another formulation is:

$$z_i = w x_i$$

$$Q = \sum_i \frac{\exp(z_i)}{\sum_j \exp(z_j)} Q_i$$

$$V = \sum_i \frac{\exp(z_i)}{\sum_j \exp(z_j)} V_i$$

$$x_n' = x_n + \mathrm{relu}(Q K_n) V$$

This also has a very clear interpretation.
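
A sketch of this second formulation (same conventions as above; here $w$ is a single learned projection to a scalar score per token):

def pooled_qv_update(x, w, Wq, Wk, Wv):
    # x: [b, l, c]; w: [c, 1]; Wq, Wk, Wv: [c, c]
    z = F.softmax(x @ w, dim=1)                     # [b, l, 1], softmax of z_i = w x_i over tokens
    Q = (z * (x @ Wq)).sum(dim=1, keepdim=True)     # [b, 1, c], pooled query
    V = (z * (x @ Wv)).sum(dim=1, keepdim=True)     # [b, 1, c], pooled value
    K = x @ Wk                                      # [b, l, c]
    gate = torch.relu((K * Q).sum(dim=-1, keepdim=True))  # [b, l, 1], relu(Q . K_n) per token
    return x + gate * V                             # every token receives the same pooled V, gated per token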

The second formulation also has a GPT form (causal mask attention):

Following the causal mask, $Q^n$ and $V^n$ are generated for each token in turn, and each token is updated as follows:

$$Q^n = \sum_{i=0}^{n} \frac{\exp(z_i)}{\sum_{j=0}^{n} \exp(z_j)} Q_i$$

$$V^n = \sum_{i=0}^{n} \frac{\exp(z_i)}{\sum_{j=0}^{n} \exp(z_j)} V_i$$

$$x_n' = x_n + \mathrm{relu}(Q^n K_n) V^n$$

Note that $Q^n$ satisfies a recurrence, so the cost is linear in the sequence length; the drawback is that a plain recurrence cannot be parallelized over time.
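
A rough loop form of that recurrence (single sequence, single head; tensor names are mine), which is also the RNN view mentioned below:

def causal_pooled_qv(Z, Qs, Ks, Vs, x):
    # Z: [l] scalar scores z_i; Qs, Ks, Vs, x: [l, c]
    num_q = torch.zeros_like(Qs[0])   # running sum of exp(z_i) * Q_i
    num_v = torch.zeros_like(Vs[0])   # running sum of exp(z_i) * V_i
    den = 0.0                         # running sum of exp(z_i)
    out = torch.empty_like(x)
    for n in range(x.size(0)):
        wn = torch.exp(Z[n])
        num_q = num_q + wn * Qs[n]
        num_v = num_v + wn * Vs[n]
        den = den + wn
        Qn, Vn = num_q / den, num_v / den            # prefix-pooled Q^n, V^n
        out[n] = x[n] + torch.relu(Qn @ Ks[n]) * Vn  # relu(Q^n . K_n) V^n
    return out

The state carried across steps is just the three running sums, which is why the cost is linear in sequence length; replacing the loop with cumsum gives the parallel training form used below.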

This second formulation has a clear meaning: each layer fixes a vector $w$ that judges how important each token's query is, so $w$ acts as prior knowledge about what should be queried first and what next, while the actual query vector $Q$ still depends on the sequence, so different sequences produce different query vectors.

Fundamentally, the formulas above use softmax-based global pooling to pool an arbitrarily long text sequence into a fixed-length one, and then let each current token run a QKV attention operation against that pooled sequence. Moreover, the scheme has an equivalent RNN form, yet it can still be trained in parallel via cumsum. A concrete implementation on top of nanoGPT follows:

# imports repeated here so the listing is self-contained (they already exist in nanoGPT's model.py)
import torch
import torch.nn as nn
from torch.nn import functional as F

# first variant: each head keeps memoryCompressDim causally pooled K/V memory slots
class linearCausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.memoryCompressDim = 2
        memoryDim = self.memoryCompressDim * config.n_head
        self.c_compress = nn.Linear(config.n_embd, memoryDim, bias=config.bias)
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        w = self.c_compress(x)
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        w = w.view(B, T, self.n_head, self.memoryCompressDim).transpose(1, 2) # (B, nh, T, mc)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2).unsqueeze(4).repeat(1, 1, 1, 1, self.memoryCompressDim) # (B, nh, T, hs, mc)
        q = q.view(B, T, self.n_head, C // self.n_head, 1).transpose(1, 2) # (B, nh, T, hs, 1)

        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2).unsqueeze(4).repeat(1, 1, 1, 1, self.memoryCompressDim) # (B, nh, T, hs, mc)
        weight_exp = torch.exp(w).view(B, self.n_head, T, 1, self.memoryCompressDim) # (B, nh, T, 1, mc); unnormalized exp, may overflow for large logits
        k_weight = weight_exp * k # (B, nh, T, hs, mc)
        v_weight = weight_exp * v # (B, nh, T, hs, mc)
        # causal prefix sums: running softmax-pooled key/value memories per head
        weight_exp_cumsum = torch.cumsum(weight_exp, dim=2)
        k_weight_cumsum = torch.cumsum(k_weight, dim=2)
        v_weight_cumsum = torch.cumsum(v_weight, dim=2)

        k_att = k_weight_cumsum / weight_exp_cumsum # (B, nh, T, hs, mc): pooled keys
        v_att = v_weight_cumsum / weight_exp_cumsum # (B, nh, T, hs, mc): pooled values

        # each token's query attends over its mc pooled memory slots
        atten = F.softmax((k_att * q).sum(dim=3, keepdim=True), dim=4) # (B, nh, T, 1, mc)
        y = (atten * v_att).sum(dim=4, keepdim=False) # (B, nh, T, hs)

        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

# second variant: kernel linear attention with an elu(x) + 1 feature map
class linearCausalSelfAttention2(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        
        self.elu1 = torch.nn.ELU()
        self.elu2 = torch.nn.ELU()
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = self.elu1(k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2).unsqueeze(4)) + 1 # (B, nh, T, hs, 1), elu+1 feature map
        q = self.elu2(q.view(B, T, self.n_head, C // self.n_head, 1).transpose(1, 2)) + 1 # (B, nh, T, hs, 1)

        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2).unsqueeze(4) # (B, nh, T, hs, 1)

        kv = k @ v.transpose(-2, -1) # (B, nh, T, khs, vhs): per-token outer product k_t v_t^T
        kv_cumsum = torch.cumsum(kv, dim=2) # causal running sum of k v^T

        k_cumsum = torch.cumsum(k, dim=2) # (B, nh, T, hs, 1): causal running sum of k

        y = (q.transpose(-2, -1) @ kv_cumsum).squeeze(3) # (B, nh, T, hs)
        y = y / (q.transpose(-2, -1) @ k_cumsum).squeeze(3) # normalize by q . cumsum(k), shape (B, nh, T, 1)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
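
A minimal smoke test for the listing above; `Cfg` is a hypothetical stand-in for nanoGPT's GPTConfig with just the fields these modules read:

from dataclasses import dataclass

@dataclass
class Cfg:                 # stand-in config, not nanoGPT's real GPTConfig
    n_embd: int = 64
    n_head: int = 4
    bias: bool = False
    dropout: float = 0.0
    block_size: int = 128

attn = linearCausalSelfAttention2(Cfg())
x = torch.randn(2, 16, 64)             # (B, T, C)
print(attn(x).shape)                   # expected: torch.Size([2, 16, 64])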