自注意力机制优化方向

链接:https://www.zhihu.com/question/602057035/answer/3169463347


自注意力(self-attention)只是注意力机制的一种。 传统的注意力机制是关注如何将输入序列与输出序列关联起来,特别是在源序列和目标序列有所不同的场景,例如在机器翻译中。 自注意力机制则关注的是单一序列内部的元素如何关联,即序列的元素与自身其他元素之间的关系。

自注意力的原理如下:假如我们有一个长度为nn的序列xx。xx的每个元素中都是一个dd维的向量。在自然语言处理中,每一个dd维的向量都可以看作是一个token embedding,把这样一条序列通过三个权重矩阵进行变换,得到三个维度为n×dn\times d的矩阵,可以定义自注意力的一般公式如下

Attention⁡(Q,K,V)=Score⁡(Q,K)V\operatorname{Attention}(Q,K,V)=\operatorname{Score}(Q,K)V \\

其中,最常用的score函数是softmax函数。采用softmax,并加上一个缩放因子,我们就得到了最经典的scaled-dot product attention (SDP):

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V)=\text{softmax}(\frac{QK^T}{\sqrt{d_k}})V \\

简单来说,该方法通过QQ和KTK^T的乘积计算输入xx的注意力,加上一个缩放因子,并对每行进行softmax归一化,最后与VV相乘得到一个n×dn\times d的输出矩阵。其实就是个加权平均的思想

为了避免点积的值变得过大,QKTQK^T的结果通常除以dk\sqrt{d_k}来进行缩放。当点积的值过大时,会导致 softmax 函数的激活值在其饱和区域,从而使得梯度很小,这对模型的学习是不利的。(softmax 函数将一个向量转换为概率分布。这个函数对于其输入中较大的值特别敏感,这意味着差异较大的值之间的差异会在 softmax 操作后被放大。这种敏感性使得当点积特别大时,softmax 可能只在一个位置具有高概率,而其他位置则接近于零。( 不过这个缩放操作对计算的整体复杂度没有影响,因为最复杂的部分在于softmax(QK)softmax(QK)的计算。

可以发现QQ和KK的乘积会产生一个n×nn\times n的矩阵,而计算这种大型矩阵的行方向softmax的复杂度是O(n2)O(n^2)。当nn很大时,这在计算和内存上都是一个挑战。对于超长序列,计算全局自注意力(每个元素与其他每个元素相乘)会变得非常困难,因此也就有了后来的一些研究对自注意力机制的复杂度进行优化,当然这样做也可能会失去自注意力能够处理长距离上下文的优势。

注意力机制优化

为了简化scaled-dot product attention的复杂度,通常会假设序列中的每一个位置并不是同等重要的,比如一个词可能跟它附近的词比较相关,距离太远的词并不一定需要关注。

为了避免计算全局注意力(当然也可能会牺牲一定的性能),目前常见的有以下几种做法:

稀疏 attention

最简单的比如窗口注意力,其实就是一个token只考虑周围一个窗口内的其他token

import torch
import torch.nn.functional as F

class LocalAttention(torch.nn.Module):
    def __init__(self, embed_size, window_size):
        super(LocalAttention, self).__init__()
        self.window_size = window_size

        # Query, Key, Value linear projections
        self.query = torch.nn.Linear(embed_size, embed_size)
        self.key = torch.nn.Linear(embed_size, embed_size)
        self.value = torch.nn.Linear(embed_size, embed_size)

    def forward(self, x):
        """
        x: [batch_size, seq_length, embed_size]
        """
        B, L, E = x.size()

        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        outputs = []
        for i in range(L):
            # Define the local window limits
            start = max(0, i - self.window_size)
            end = min(L, i + self.window_size + 1)

            # Extract local chunks
            local_queries = queries[:, i, :].unsqueeze(1)  # [B, 1, E]
            local_keys = keys[:, start:end, :]  # [B, W, E]
            local_values = values[:, start:end, :]  # [B, W, E]

            # Local attention score
            scores = torch.bmm(local_queries, local_keys.transpose(1, 2)) / E**0.5  # [B, 1, W]
            attn_probs = F.softmax(scores, dim=-1)  # [B, 1, W]

            # Compute output
            output = torch.bmm(attn_probs, local_values).squeeze(1)  # [B, E]
            outputs.append(output)

        return torch.stack(outputs, dim=1)  # [B, L, E]

# Example usage:
input_tensor = torch.randn(32, 100, 512)  # batch of 32, sequence length of 100, embedding size of 512
local_attn = LocalAttention(embed_size=512, window_size=5)
output_tensor = local_attn(input_tensor)
print(output_tensor.shape)  # [32, 100, 512]

类似思想的还有Sparse transformer、Longformer,核心思想都是考虑一些窗口、间隔之类的操作来让一个token避免计算全局注意力,但又不能目光短浅

矩阵分解

在矩阵因子分解方法中,我们通常认为注意力矩阵是低秩的,这意味着矩阵里的元素并不都是相互独立的。所以,我们可以将这个矩阵拆解并使用一个更小的矩阵来近似它,从而能更高效地计算softmax的结果。

Linformer的核心思想就是通过对注意力矩阵的低秩分解来降低计算复杂度。首先,他们通过实验证明,当应用奇异值分解(SVD)时,只需用其前几个最大的奇异值就可以恢复注意力矩阵,这说明注意力矩阵是低秩的。接下来,他们利用 Johnson-Lindenstraus 引理证明注意力矩阵可以用极低的误差被近似为一个低秩矩阵

同时作者也提到,为每一个自注意力矩阵计算 SVD 实际上会引入更多的计算复杂度。因此,他们选择在 VV 和 KK 后面加上两个线性投影矩阵,这样就可以将原来的 (nd)(nd) 矩阵有效地映射到 (kd)(kd) 的低维矩阵,其中 kk 是减小后的维度。换个角度想想,其实类似于一个句子的长度虽然有nn,但是压缩到kk其实就可以保留大部分信息了

import torch
import torch.nn as nn

class SimplifiedLinformerSelfAttention(nn.Module):
    def __init__(self, dim, seq_len, k=256, heads=8, dropout=0.):
        super().__init__()

        self.heads = heads
        self.k = k
        self.seq_len = seq_len

        # 定义全连接层以获得Q,K,V
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)

        # 定义投影参数
        self.proj_k = nn.Parameter(torch.randn(seq_len, k))
        self.proj_v = nn.Parameter(torch.randn(seq_len, k))

        # 定义Dropout层和输出层
        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, context=None):
        b, n, d = x.shape

        # 获取Q, K, V
        queries = self.to_q(x)
        keys = self.to_k(x)
        values = self.to_v(x)

        # 使用投影参数对K, V进行投影
        keys = torch.einsum('bnd,nk->bkd', keys, self.proj_k)
        values = torch.einsum('bnd,nk->bkd', values, self.proj_v)

        # 注意力计算
        dots = torch.einsum('bnd,bkd->bnk', queries, keys)
        attn = self.dropout(dots.softmax(dim=-1))
        out = torch.einsum('bnk,bkd->bnd', attn, values)

        return self.to_out(out)

局部敏感哈希

局部敏感哈希(LSH)是一种高效寻找近似最近邻的技巧。其核心思想是选择特定的哈希函数,使得在高维空间里,两个点p和q如若靠近,则它们的哈希值应相同。这样,所有的点就可以被分配到不同的哈希桶中,大大提高了寻找某个点的最近邻的效率,因为我们只需考虑同一个哈希桶内的点。在自注意力机制中,这种方法可以用于快速计算P,方法是在Q和K上应用LSH,仅对近似的元素进行计算,而非直接进行Q和K的全量计算。

Reformer提出了在自注意力机制中使用LSH以提高效率。他们强调,因为softmax运算主要受到最大值的影响,所以在Q中的每个查询qi只需考虑K中与它最接近或在同一哈希桶的键。下面是一个简化版的LSHSelfAttention代码实现

import torch
import torch.nn.functional as F

class LSHSelfAttention(torch.nn.Module):
    def __init__(self, dim, heads=8, bucket_size=64, n_hashes=1):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.bucket_size = bucket_size
        self.n_hashes = n_hashes

        self.query = torch.nn.Linear(dim, dim * heads, bias=False)
        self.key = torch.nn.Linear(dim, dim * heads, bias=False)
        self.value = torch.nn.Linear(dim, dim * heads, bias=False)
        self.out = torch.nn.Linear(dim * heads, dim)

    def forward(self, x):
        # Splitting the input into multiple heads
        Q = self.query(x).view(x.size(0), -1, self.heads, self.dim // self.heads).transpose(1, 2)
        K = self.key(x).view(x.size(0), -1, self.heads, self.dim // self.heads).transpose(1, 2)
        V = self.value(x).view(x.size(0), -1, self.heads, self.dim // self.heads).transpose(1, 2)

        # Hashing to buckets using LSH
        hash_size = x.size(1) // self.bucket_size
        hashes = ((self.bucket_size * K[..., 0]) + K[..., 1]).argmax(dim=-1) % hash_size
        for _ in range(self.n_hashes - 1):
            hashes = torch.cat([hashes, ((self.bucket_size * K[..., 0]) + K[..., 1]).argmax(dim=-1) % hash_size], -1)

        # Computing attention within each bucket
        attn = torch.zeros_like(K)
        for i in range(hashes.size(-1)):
            mask = hashes == i
            attn += (Q[mask] @ K[mask].transpose(-2, -1)) / self.dim ** 0.5
        attn = F.softmax(attn, dim=-1)
        out = attn @ V

        # Combining heads
        out = out.transpose(1, 2).contiguous().view(x.size(0),

Kernel attention

核技巧是一种在机器学习中经常使用的方法,特别是在支持向量机中。核技巧的主要思想是通过某种函数(核函数)隐式地在高维空间中计算点积,而无需显式地计算高维表示。这允许算法在原始空间进行计算,同时实际上它在高维特征空间中工作。

Kernel attention 是一种近似的注意力机制,其主要思想是使用核技巧(kernel trick)来估计原始注意力的计算。这种方法尤其在长序列上很有效,因为它可以显著减少计算和存储的需求。

常见的方法比如FAVOR+,它的关键思想是使用特征映射 ϕ \phi 来近似这个 softmax 函数,从而避免显式地进行求和操作:

aij≈ϕ(sij)Tϕ(sik)a_{ij} \approx \phi(s_{ij})^T \phi(s_{ik}) \\

其中 ϕ \phi 是随机特征映射,选择 ϕ \phi 的一个有效方法是使用正交随机特征。通过选择合适的 ϕ \phi ,FAVOR+ 可以在减少计算复杂度的同时保持较高的准确性。

以下是基于 PyTorch 实现的一个简化版 FAVOR+:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FavorAttention(nn.Module):
    def __init__(self, dim, num_features=256):
        super(FavorAttention, self).__init__()
        self.dim = dim
        self.num_features = num_features

        # 随机特征映射
        self.omega = torch.randn(self.dim, self.num_features)
        self.bias = 2 * torch.pi * torch.rand(self.num_features)

    def feature_map(self, x):
        # 使用 cos 和 sin 函数进行特征映射
        return torch.cat([torch.sin(x @ self.omega + self.bias), 
                          torch.cos(x @ self.omega + self.bias)], dim=-1) / self.num_features**0.5

    def forward(self, Q, K, V):
        # 近似注意力权重计算
        Q_features = self.feature_map(Q)
        K_features = self.feature_map(K)
        
        A = Q_features @ K_features.transpose(-2, -1)
        return A @ V

# 测试
batch_size, seq_len, d_model = 32, 100, 64
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

attention = FavorAttention(d_model)
output = attention(Q, K, V)

KV-Cache

自回归模型(例如GPT)的推理过程中,我们通常使用增量解码一次生成一个输出token。对于每一个新的token,我们可以利用在之前步骤中做的计算,而不必为整个序列重新做整个注意力计算。这不仅节省了计算开销,而且加速了生成过程。

KV-Cache的主要思路是:当我们一次生成一个token时,之前token的key和value不会改变。因此,我们可以缓存(或记住)这些值,并在下一个token的计算中重复使用它们。

具体来说,对于每一个生成的token:

  1. 为新令牌计算 K K 和 V V 。
  2. 将这个新的 K K 和 V V 添加到缓存的值中。
  3. 使用整个 K K 和 V V 序列(旧的缓存值+新值)进行注意力机制的计算。

基于这一想法的示例代码实现如下:

import torch.nn as nn

class IncrementalAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(IncrementalAttention, self).__init__()
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // self.num_heads

        # Q, K, V 的线性层
        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim=-1)

        # 初始化K和V的空缓存
        self.k_cache = None
        self.v_cache = None

    def forward(self, q, k, v, mask=None):
        # 计算Q
        q = self.WQ(q).view(-1, self.num_heads, self.depth)

        # 计算新令牌的K和V
        k_new = self.WK(k).view(-1, self.num_heads, self.depth)
        v_new = self.WV(v).view(-1, self.num_heads, self.depth)

        # 添加到缓存
        if self.k_cache is not None:
            k = torch.cat([self.k_cache, k_new], dim=1)
            v = torch.cat([self.v_cache, v_new], dim=1)
        else:
            k = k_new
            v = v_new
        
        # 更新缓存以供下一次迭代使用
        self.k_cache = k
        self.v_cache = v

        # 注意力机制(简化了,以便简短)
        scores = torch.matmul(q, k.transpose(1, 2)) / self.depth**0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = self.softmax(scores)
        output = torch.matmul(attn_weights, v)

        return output

Multi-Query Attention

在2019年,Shazeer 提出了一种称为多查询注意力(MQA,Multi-Query Attention)的多头注意力(MHA,Multi-Head Attention)的改进算法。它提高了注意力计算的机器效率,同时只造成了较小的准确性降低。这一方法目前也被广泛用在LLM中。

传统的多头注意力实质上是将输入分成多个头部,并为每个头部独立计算注意力。在MHA中,QQ、KK 和 VV 都根据每个head进行不同的转换。这在头部数量较多时可能会计算密集。多查询注意力简化了这个过程,尤其是在KK和VV的部分。与为每个head提供多个、单独的KK和VV映射不同,MQA为所有head应用单一的KK和VV转换。只有QQ值才有多个head。

通过在head之间共享相同的K和V转换,参数和操作的数量大大减少。

在传统的多头注意力中,每个head的转换可能看起来像这样:

for i in range(num_heads):
    Qi = WQi @ Q
    Ki = WKi @ K
    Vi = WVi @ V
    ...

在多查询注意力中,这将更改为:

K_shared = WK @ K
V_shared = WV @ V

for i in range(num_heads):
    Qi = WQi @ Q
    ...

其中WKWV是K和V的共享权重,而WQi矩阵对于每个查询头都是不同的。

Grouped-Query Attention

了解了Multi-head和Multi-Query的思想后,Grouped-Query Attention就很好理解了,其实就是它们的一个折中方案,KK和VV的数量减少一些,但又不是只有一组这么少。

实现上其实就是KK和VV只需要几个head,然后通过repeat_kv复制多份得到维度和QQ一样的tensor,从而能够进行注意力计算

xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

# xxxxxxxx

# grouped multiquery attention: expand out keys and values
xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值