链接:https://www.zhihu.com/question/602057035/answer/3169463347
自注意力(self-attention)只是注意力机制的一种。 传统的注意力机制是关注如何将输入序列与输出序列关联起来,特别是在源序列和目标序列有所不同的场景,例如在机器翻译中。 自注意力机制则关注的是单一序列内部的元素如何关联,即序列的元素与自身其他元素之间的关系。
自注意力的原理如下:假如我们有一个长度为nn的序列xx。xx的每个元素中都是一个dd维的向量。在自然语言处理中,每一个dd维的向量都可以看作是一个token embedding,把这样一条序列通过三个权重矩阵进行变换,得到三个维度为n×dn\times d的矩阵,可以定义自注意力的一般公式如下
Attention(Q,K,V)=Score(Q,K)V\operatorname{Attention}(Q,K,V)=\operatorname{Score}(Q,K)V \\
其中,最常用的score函数是softmax函数。采用softmax,并加上一个缩放因子,我们就得到了最经典的scaled-dot product attention (SDP):
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V)=\text{softmax}(\frac{QK^T}{\sqrt{d_k}})V \\
简单来说,该方法通过QQ和KTK^T的乘积计算输入xx的注意力,加上一个缩放因子,并对每行进行softmax归一化,最后与VV相乘得到一个n×dn\times d的输出矩阵。其实就是个加权平均的思想
为了避免点积的值变得过大,QKTQK^T的结果通常除以dk\sqrt{d_k}来进行缩放。当点积的值过大时,会导致 softmax 函数的激活值在其饱和区域,从而使得梯度很小,这对模型的学习是不利的。(softmax 函数将一个向量转换为概率分布。这个函数对于其输入中较大的值特别敏感,这意味着差异较大的值之间的差异会在 softmax 操作后被放大。这种敏感性使得当点积特别大时,softmax 可能只在一个位置具有高概率,而其他位置则接近于零。( 不过这个缩放操作对计算的整体复杂度没有影响,因为最复杂的部分在于softmax(QK)softmax(QK)的计算。
可以发现QQ和KK的乘积会产生一个n×nn\times n的矩阵,而计算这种大型矩阵的行方向softmax的复杂度是O(n2)O(n^2)。当nn很大时,这在计算和内存上都是一个挑战。对于超长序列,计算全局自注意力(每个元素与其他每个元素相乘)会变得非常困难,因此也就有了后来的一些研究对自注意力机制的复杂度进行优化,当然这样做也可能会失去自注意力能够处理长距离上下文的优势。
注意力机制优化
为了简化scaled-dot product attention的复杂度,通常会假设序列中的每一个位置并不是同等重要的,比如一个词可能跟它附近的词比较相关,距离太远的词并不一定需要关注。
为了避免计算全局注意力(当然也可能会牺牲一定的性能),目前常见的有以下几种做法:
稀疏 attention
最简单的比如窗口注意力,其实就是一个token只考虑周围一个窗口内的其他token
import torch
import torch.nn.functional as F
class LocalAttention(torch.nn.Module):
def __init__(self, embed_size, window_size):
super(LocalAttention, self).__init__()
self.window_size = window_size
# Query, Key, Value linear projections
self.query = torch.nn.Linear(embed_size, embed_size)
self.key = torch.nn.Linear(embed_size, embed_size)
self.value = torch.nn.Linear(embed_size, embed_size)
def forward(self, x):
"""
x: [batch_size, seq_length, embed_size]
"""
B, L, E = x.size()
queries = self.query(x)
keys = self.key(x)
values = self.value(x)
outputs = []
for i in range(L):
# Define the local window limits
start = max(0, i - self.window_size)
end = min(L, i + self.window_size + 1)
# Extract local chunks
local_queries = queries[:, i, :].unsqueeze(1) # [B, 1, E]
local_keys = keys[:, start:end, :] # [B, W, E]
local_values = values[:, start:end, :] # [B, W, E]
# Local attention score
scores = torch.bmm(local_queries, local_keys.transpose(1, 2)) / E**0.5 # [B, 1, W]
attn_probs = F.softmax(scores, dim=-1) # [B, 1, W]
# Compute output
output = torch.bmm(attn_probs, local_values).squeeze(1) # [B, E]
outputs.append(output)
return torch.stack(outputs, dim=1) # [B, L, E]
# Example usage:
input_tensor = torch.randn(32, 100, 512) # batch of 32, sequence length of 100, embedding size of 512
local_attn = LocalAttention(embed_size=512, window_size=5)
output_tensor = local_attn(input_tensor)
print(output_tensor.shape) # [32, 100, 512]
类似思想的还有Sparse transformer、Longformer,核心思想都是考虑一些窗口、间隔之类的操作来让一个token避免计算全局注意力,但又不能目光短浅。
矩阵分解
在矩阵因子分解方法中,我们通常认为注意力矩阵是低秩的,这意味着矩阵里的元素并不都是相互独立的。所以,我们可以将这个矩阵拆解并使用一个更小的矩阵来近似它,从而能更高效地计算softmax的结果。
Linformer的核心思想就是通过对注意力矩阵的低秩分解来降低计算复杂度。首先,他们通过实验证明,当应用奇异值分解(SVD)时,只需用其前几个最大的奇异值就可以恢复注意力矩阵,这说明注意力矩阵是低秩的。接下来,他们利用 Johnson-Lindenstraus 引理证明注意力矩阵可以用极低的误差被近似为一个低秩矩阵。
同时作者也提到,为每一个自注意力矩阵计算 SVD 实际上会引入更多的计算复杂度。因此,他们选择在 VV 和 KK 后面加上两个线性投影矩阵,这样就可以将原来的 (nd)(nd) 矩阵有效地映射到 (kd)(kd) 的低维矩阵,其中 kk 是减小后的维度。换个角度想想,其实类似于一个句子的长度虽然有nn,但是压缩到kk其实就可以保留大部分信息了。
import torch
import torch.nn as nn
class SimplifiedLinformerSelfAttention(nn.Module):
def __init__(self, dim, seq_len, k=256, heads=8, dropout=0.):
super().__init__()
self.heads = heads
self.k = k
self.seq_len = seq_len
# 定义全连接层以获得Q,K,V
self.to_q = nn.Linear(dim, dim, bias=False)
self.to_k = nn.Linear(dim, dim, bias=False)
self.to_v = nn.Linear(dim, dim, bias=False)
# 定义投影参数
self.proj_k = nn.Parameter(torch.randn(seq_len, k))
self.proj_v = nn.Parameter(torch.randn(seq_len, k))
# 定义Dropout层和输出层
self.dropout = nn.Dropout(dropout)
self.to_out = nn.Linear(dim, dim)
def forward(self, x, context=None):
b, n, d = x.shape
# 获取Q, K, V
queries = self.to_q(x)
keys = self.to_k(x)
values = self.to_v(x)
# 使用投影参数对K, V进行投影
keys = torch.einsum('bnd,nk->bkd', keys, self.proj_k)
values = torch.einsum('bnd,nk->bkd', values, self.proj_v)
# 注意力计算
dots = torch.einsum('bnd,bkd->bnk', queries, keys)
attn = self.dropout(dots.softmax(dim=-1))
out = torch.einsum('bnk,bkd->bnd', attn, values)
return self.to_out(out)
局部敏感哈希
局部敏感哈希(LSH)是一种高效寻找近似最近邻的技巧。其核心思想是选择特定的哈希函数,使得在高维空间里,两个点p和q如若靠近,则它们的哈希值应相同。这样,所有的点就可以被分配到不同的哈希桶中,大大提高了寻找某个点的最近邻的效率,因为我们只需考虑同一个哈希桶内的点。在自注意力机制中,这种方法可以用于快速计算P,方法是在Q和K上应用LSH,仅对近似的元素进行计算,而非直接进行Q和K的全量计算。
Reformer提出了在自注意力机制中使用LSH以提高效率。他们强调,因为softmax运算主要受到最大值的影响,所以在Q中的每个查询qi只需考虑K中与它最接近或在同一哈希桶的键。下面是一个简化版的LSHSelfAttention代码实现
import torch
import torch.nn.functional as F
class LSHSelfAttention(torch.nn.Module):
def __init__(self, dim, heads=8, bucket_size=64, n_hashes=1):
super().__init__()
self.dim = dim
self.heads = heads
self.bucket_size = bucket_size
self.n_hashes = n_hashes
self.query = torch.nn.Linear(dim, dim * heads, bias=False)
self.key = torch.nn.Linear(dim, dim * heads, bias=False)
self.value = torch.nn.Linear(dim, dim * heads, bias=False)
self.out = torch.nn.Linear(dim * heads, dim)
def forward(self, x):
# Splitting the input into multiple heads
Q = self.query(x).view(x.size(0), -1, self.heads, self.dim // self.heads).transpose(1, 2)
K = self.key(x).view(x.size(0), -1, self.heads, self.dim // self.heads).transpose(1, 2)
V = self.value(x).view(x.size(0), -1, self.heads, self.dim // self.heads).transpose(1, 2)
# Hashing to buckets using LSH
hash_size = x.size(1) // self.bucket_size
hashes = ((self.bucket_size * K[..., 0]) + K[..., 1]).argmax(dim=-1) % hash_size
for _ in range(self.n_hashes - 1):
hashes = torch.cat([hashes, ((self.bucket_size * K[..., 0]) + K[..., 1]).argmax(dim=-1) % hash_size], -1)
# Computing attention within each bucket
attn = torch.zeros_like(K)
for i in range(hashes.size(-1)):
mask = hashes == i
attn += (Q[mask] @ K[mask].transpose(-2, -1)) / self.dim ** 0.5
attn = F.softmax(attn, dim=-1)
out = attn @ V
# Combining heads
out = out.transpose(1, 2).contiguous().view(x.size(0),
Kernel attention
核技巧是一种在机器学习中经常使用的方法,特别是在支持向量机中。核技巧的主要思想是通过某种函数(核函数)隐式地在高维空间中计算点积,而无需显式地计算高维表示。这允许算法在原始空间进行计算,同时实际上它在高维特征空间中工作。
Kernel attention 是一种近似的注意力机制,其主要思想是使用核技巧(kernel trick)来估计原始注意力的计算。这种方法尤其在长序列上很有效,因为它可以显著减少计算和存储的需求。
常见的方法比如FAVOR+,它的关键思想是使用特征映射 ϕ \phi 来近似这个 softmax 函数,从而避免显式地进行求和操作:
aij≈ϕ(sij)Tϕ(sik)a_{ij} \approx \phi(s_{ij})^T \phi(s_{ik}) \\
其中 ϕ \phi 是随机特征映射,选择 ϕ \phi 的一个有效方法是使用正交随机特征。通过选择合适的 ϕ \phi ,FAVOR+ 可以在减少计算复杂度的同时保持较高的准确性。
以下是基于 PyTorch 实现的一个简化版 FAVOR+:
import torch
import torch.nn as nn
import torch.nn.functional as F
class FavorAttention(nn.Module):
def __init__(self, dim, num_features=256):
super(FavorAttention, self).__init__()
self.dim = dim
self.num_features = num_features
# 随机特征映射
self.omega = torch.randn(self.dim, self.num_features)
self.bias = 2 * torch.pi * torch.rand(self.num_features)
def feature_map(self, x):
# 使用 cos 和 sin 函数进行特征映射
return torch.cat([torch.sin(x @ self.omega + self.bias),
torch.cos(x @ self.omega + self.bias)], dim=-1) / self.num_features**0.5
def forward(self, Q, K, V):
# 近似注意力权重计算
Q_features = self.feature_map(Q)
K_features = self.feature_map(K)
A = Q_features @ K_features.transpose(-2, -1)
return A @ V
# 测试
batch_size, seq_len, d_model = 32, 100, 64
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)
attention = FavorAttention(d_model)
output = attention(Q, K, V)
KV-Cache
在自回归模型(例如GPT)的推理过程中,我们通常使用增量解码一次生成一个输出token。对于每一个新的token,我们可以利用在之前步骤中做的计算,而不必为整个序列重新做整个注意力计算。这不仅节省了计算开销,而且加速了生成过程。
KV-Cache的主要思路是:当我们一次生成一个token时,之前token的key和value不会改变。因此,我们可以缓存(或记住)这些值,并在下一个token的计算中重复使用它们。
具体来说,对于每一个生成的token:
- 为新令牌计算 K K 和 V V 。
- 将这个新的 K K 和 V V 添加到缓存的值中。
- 使用整个 K K 和 V V 序列(旧的缓存值+新值)进行注意力机制的计算。
基于这一想法的示例代码实现如下:
import torch.nn as nn
class IncrementalAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(IncrementalAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
self.depth = d_model // self.num_heads
# Q, K, V 的线性层
self.WQ = nn.Linear(d_model, d_model)
self.WK = nn.Linear(d_model, d_model)
self.WV = nn.Linear(d_model, d_model)
self.softmax = nn.Softmax(dim=-1)
# 初始化K和V的空缓存
self.k_cache = None
self.v_cache = None
def forward(self, q, k, v, mask=None):
# 计算Q
q = self.WQ(q).view(-1, self.num_heads, self.depth)
# 计算新令牌的K和V
k_new = self.WK(k).view(-1, self.num_heads, self.depth)
v_new = self.WV(v).view(-1, self.num_heads, self.depth)
# 添加到缓存
if self.k_cache is not None:
k = torch.cat([self.k_cache, k_new], dim=1)
v = torch.cat([self.v_cache, v_new], dim=1)
else:
k = k_new
v = v_new
# 更新缓存以供下一次迭代使用
self.k_cache = k
self.v_cache = v
# 注意力机制(简化了,以便简短)
scores = torch.matmul(q, k.transpose(1, 2)) / self.depth**0.5
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = self.softmax(scores)
output = torch.matmul(attn_weights, v)
return output
Multi-Query Attention
在2019年,Shazeer 提出了一种称为多查询注意力(MQA,Multi-Query Attention)的多头注意力(MHA,Multi-Head Attention)的改进算法。它提高了注意力计算的机器效率,同时只造成了较小的准确性降低。这一方法目前也被广泛用在LLM中。
传统的多头注意力实质上是将输入分成多个头部,并为每个头部独立计算注意力。在MHA中,QQ、KK 和 VV 都根据每个head进行不同的转换。这在头部数量较多时可能会计算密集。多查询注意力简化了这个过程,尤其是在KK和VV的部分。与为每个head提供多个、单独的KK和VV映射不同,MQA为所有head应用单一的KK和VV转换。只有QQ值才有多个head。
通过在head之间共享相同的K和V转换,参数和操作的数量大大减少。
在传统的多头注意力中,每个head的转换可能看起来像这样:
for i in range(num_heads):
Qi = WQi @ Q
Ki = WKi @ K
Vi = WVi @ V
...
在多查询注意力中,这将更改为:
K_shared = WK @ K
V_shared = WV @ V
for i in range(num_heads):
Qi = WQi @ Q
...
其中WK
和WV
是K和V的共享权重,而WQi
矩阵对于每个查询头都是不同的。
Grouped-Query Attention
了解了Multi-head和Multi-Query的思想后,Grouped-Query Attention就很好理解了,其实就是它们的一个折中方案,KK和VV的数量减少一些,但又不是只有一组这么少。
实现上其实就是KK和VV只需要几个head,然后通过repeat_kv复制多份得到维度和QQ一样的tensor,从而能够进行注意力计算
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# xxxxxxxx
# grouped multiquery attention: expand out keys and values
xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)