找到一种更简洁的形式,如下:
message passing公式。其含义是,每个token产生一个消息
Q
i
Q_i
Qi。然后消息通过权重
m
i
m_i
mi加权合并,最后通过权重
s
i
s_i
si消息分发。极限情况下(即softmax中有一个token的值为1,其余token值为0)每一次传递,会将某个token的消息传给另一个token。问题在于极限情况下,每一层只能设定固定的传递次数。在全连接图的情况下,传递次数应为
O
(
n
2
)
O(n^2)
O(n2)。在softmax情况下,这种情况可能有所改善。权重平均的情况下,相当于每个token加入了一个相同常数。令
m
i
=
w
1
x
i
,
s
i
=
w
2
x
i
m_i = w_1x_i, s_i = w_2x_i
mi=w1xi,si=w2xi,有
Q
=
∑
i
e
x
p
(
m
i
)
∑
j
e
x
p
(
m
j
)
Q
i
Q = \sum_i\dfrac{exp(m_i)}{\sum_jexp(m_j)}Q_i
Q=i∑∑jexp(mj)exp(mi)Qi
x
n
′
=
x
n
+
e
x
p
(
s
i
)
∑
j
e
x
p
(
s
j
)
Q
x_n' = x_n + \dfrac{exp(s_i)}{\sum_jexp(s_j)}Q
xn′=xn+∑jexp(sj)exp(si)Q
如果进行
k
k
k次合并,和
k
k
k次分发,总体公式如下:
x ′ = x + s o f t m a x ( S x , d i m = l ) s o f t m a x ( M x , d i m = l ) T Q x x' = x + softmax(Sx,dim = l) softmax(Mx, dim=l)^T Qx x′=x+softmax(Sx,dim=l)softmax(Mx,dim=l)TQx
,其中 Q x : [ b a t c h , l e n , c ] , s o f t m a x ( M x ) T : [ b , k , l ] , e x p ( S x ) : [ b , l , k ] ,其中Qx:[batch, len, c], softmax(Mx)^T:[b, k, l], exp(Sx):[b,l,k] ,其中Qx:[batch,len,c],softmax(Mx)T:[b,k,l],exp(Sx):[b,l,k]
从公式上来看,更像一个低秩方法。需要看一下与其它的低秩方法的区别。
与之对比的self_attention公式如下,可以发现其实就是拆解了self attention的softmax,比较巧合。
x ′ = x + s o f t m a x ( ( Q x ) T ( K x ) , d i m = l ) ∗ ( V x ) x' = x + softmax((Qx)^T(Kx),dim=l) * (Vx) x′=x+softmax((Qx)T(Kx),dim=l)∗(Vx)
其中
Q x : [ b , l , k ] , K x : [ b , k , l ] , V x : [ b , l , c ] Qx:[b,l,k], Kx:[b,k,l], Vx:[b,l,c] Qx:[b,l,k],Kx:[b,k,l],Vx:[b,l,c]
参考了一下linear transformer的论文,主要是https://blog.csdn.net/hymn1993/article/details/125254897。发现确实是低秩方法的一种,只不过用softmax作为核进行映射。不过比较巧妙的是,这套方法有明确的可解释意义。
transformer O ( n 2 ) O(n^2) O(n2)复杂度的关键点在于,对每个token都查询了一次。因此降低复杂度的一个行之有效的方法是降低查询的次数。因此提出竞争查询的方法。公式如下:
Q = ∑ i e x p ( z i ) ∑ j e x p ( z j ) Q i Q = \sum_i\dfrac{exp(z_i)}{\sum_jexp(z_j)}Q_i Q=i∑∑jexp(zj)exp(zi)Qi
x n ′ = x n + e x p ( z n ) ∑ j e x p ( z j ) ∑ m ( e x p ( Q ∗ K m ) ∑ k e x p ( Q ∗ K k ) V m ) x_n' = x_n + \dfrac{exp(z_n)}{\sum_jexp(z_j)}\sum_m (\dfrac{exp( Q* K_m)}{\sum_kexp(Q*K_k)}V_m) xn′=xn+∑jexp(zj)exp(zn)m∑(∑kexp(Q∗Kk)exp(Q∗Km)Vm)
Z Z Z向量为竞争向量,通过softmax归一化得到分布在tokens上的权重,根据 Z Z Z的权重对所有的query向量 Q i Q_i Qi进行求和,得到竞争成功的 Q Q Q向量。可以理解为这一步将所有要查询的东西编码到同一个向量中。然后正常按照transformer的办法用Q向量与每个token的key向量 K m K_m Km和 V m V_m Vm得到更新向量。然后按照 Q Q Q向量的比例,依次按照比例把更新向量赋值给所有的token。可以看出当 Q i Q_i Qi的比例为极限情况(0,0,…,1,…,0,0)时,相当于只对比例为1的token做查询。
另外一种公式是:
z i = w x i z_i = w x_i zi=wxi
Q = ∑ i e x p ( z i ) ∑ j e x p ( z j ) Q i Q = \sum_i\dfrac{exp(z_i)}{\sum_jexp(z_j)}Q_i Q=i∑∑jexp(zj)exp(zi)Qi
V = ∑ i e x p ( z i ) ∑ j e x p ( z j ) V i V = \sum_i\dfrac{exp(z_i)}{\sum_jexp(z_j)}V_i V=i∑∑jexp(zj)exp(zi)Vi
x n ′ = x n + r e l u ( Q K n ) V x_n' = x_n + relu(QK_n)V xn′=xn+relu(QKn)V
也是具备非常明确的意义。
第二种存在GPT形式(casual mask attention):
要依次根据mask,生成每个token的 Q n , V n Q^n,V^n Qn,Vn。每个token的更新如下:
Q n = ∑ i = 0 n e x p ( z i ) ∑ j = 0 n e x p ( z j ) Q i Q^n = \sum_{i = 0}^n\dfrac{exp(z_i)}{\sum_{j = 0}^nexp(z_j)}Q_i Qn=i=0∑n∑j=0nexp(zj)exp(zi)Qi
V n = ∑ i = 0 n e x p ( z i ) ∑ j = 0 n e x p ( z j ) V i V^n = \sum_{i = 0}^n\dfrac{exp(z_i)}{\sum_{j = 0}^nexp(z_j)}V_i Vn=i=0∑n∑j=0nexp(zj)exp(zi)Vi
x n ′ = x n + r e l u ( Q n K n ) V n x_n' = x_n + relu(Q^nK_n)V^n xn′=xn+relu(QnKn)Vn
注意 Q n Q_n Qn的计算是存在递归公式的,因此其复杂度为seq len线性相关,缺点是无法并行。
第二个公式是有明确意义的。每一层都预定了一个w,用来判断query的重要程度。因此w是一个先验知识,用来决定应该先查询什么,然后再查询什么。但是真正的查询向量Q又是和序列有关的,不同的序列有不同的查询向量。
上面的公式,其根本的意义是,使用基于softmax的全局池化,将任意长的文本序列池化成一个固定长的序列。然后让每个当前字符与该池化后的序列进行QKV attention操作。而且这种方案存在RNN的等价形式,但是又可通过cumsum进行并行训练。在nanoGPT上的具体实现如下:
class linearCausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
self.memoryCompressDim = 2
memoryDim = self.memoryCompressDim * config.n_head
self.c_compress = nn.Linear(config.n_embd, memoryDim, bias=config.bias)
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
# regularization
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
self.n_head = config.n_head
self.n_embd = config.n_embd
self.dropout = config.dropout
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
def forward(self, x):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
w = self.c_compress(x)
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
w = w.view(B, T, self.n_head, self.memoryCompressDim).transpose(1, 2) # (B, nh, T, 4)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2).unsqueeze(4).repeat(1, 1, 1, 1, self.memoryCompressDim) # (B, nh, T, hs, 4)
q = q.view(B, T, self.n_head, C // self.n_head, 1).transpose(1, 2) # (B, nh, T, hs, 1)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2).unsqueeze(4).repeat(1, 1, 1, 1, self.memoryCompressDim) # (B, nh, T, hs, 4)
weight_exp = torch.exp(w).view(B, self.n_head, T, 1, self.memoryCompressDim) # (B, nh, T, 1, 4)
k_weight = weight_exp * k # (B, nh, T, hs, 4)
v_weight = weight_exp * v # (B, nh, T, hs, 4)
weight_exp_cumsum = torch.cumsum(weight_exp, dim = 2)
k_weight_cum_sum = torch.cumsum(k_weight, dim = 2)
v_weight_cumsum = torch.cumsum(v_weight, dim = 2)
k_att = k_weight_cum_sum / weight_exp_cumsum # (B, nh, T, hs, 4)
v_att = v_weight_cumsum / weight_exp_cumsum # (B, nh, T, hs, 4)
atten = F.softmax((k_att * q).sum(dim = 3, keepdim=True), 4) # (B, nh, T, 1, 4)
y = (atten * v_att).sum(dim = 4, keepdim=False) # (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_dropout(self.c_proj(y))
return y
class linearCausalSelfAttention2(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
# regularization
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
self.n_head = config.n_head
self.n_embd = config.n_embd
self.dropout = config.dropout
self.elu1 = torch.nn.ELU()
self.elu2 = torch.nn.ELU()
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
def forward(self, x):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
k = self.elu1(k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2).unsqueeze(4)) + 1 # (B, nh, T, hs, 1)
q = self.elu2(q.view(B, T, self.n_head, C // self.n_head, 1).transpose(1, 2)) + 1 # (B, nh, T, hs, 1)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2).unsqueeze(4) # (B, nh, T, hs, 1)
kv = (k @ v.transpose(-2, -1)) # (B, nh, T, khs, vhs)
kv_cumsum = torch.cumsum(kv, dim = 2)
k_cumsum = torch.cumsum(k, dim = 2) # (B, nh, T, hs, 1)
y = (q.transpose(-2, -1) @ kv_cumsum).view(B, T, self.n_head, C // self.n_head) #(B, nh, T, hs)
y = y / ((q.transpose(-2, -1) @ k_cumsum).view(B, T, self.n_head, 1))
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_dropout(self.c_proj(y))
return y