【LLM加速】注意力优化(基于位置/内容的稀疏注意力 | flashattention)

note

(1)近似注意力:

  • Routing Transformer采用K-means 聚类方法,针对Query和Key进行聚类,类中心向量集合为 { μ i } i = 1 k \left\{\boldsymbol{\mu}_i\right\}_{i=1}^k {μi}i=1k ,其中k 是类中心的个数。每个Query 只与其处在相同簇 (Cluster) 下的Key 进行交互。
  • Reformer 则采用局部敏感哈希 (Local-Sensitive Hashing,LSH) 的方法为每个Query 选择Key-Value 对。其主要思想是使用LSH 函数对Query 和Key 进行哈希计算,将它们划分到多个桶内,以提升在同一个桶内的Query 和Key 参与交互的概率。

(2)在Transformer 结构中,自注意力机制的时间和存储复杂度与序列的长度呈平方的关系,因此占用了大量的计算设备内存并消耗了大量的计算资源。如何优化自注意力机制的时空复杂度、增强计算效率是大语言模型面临的重要问题。

  • 方法一:从近似注意力出发,旨在减少注意力计算和内存需求,提出了稀疏近似、低秩近似等方法。
  • 方法二:从计算加速设备本身的特性出发,研究如何更好地利用硬件特性对Transformer 中的注意力层进行高效计算。

(3)FlashAttention目标是尽可能高效地使用SRAM来加快计算速度,避免从全局内存中读取和写入注意力矩阵。达成该目标需要能做到在不访问整个输入的情况下计算Softmax函数,并且后向传播中不能存储中间注意力矩阵。

一、近似注意力

对一些训练好的Transformer 结构中的注意力矩阵进行分析时发现,其中很多是稀疏的,因此可以通过限制Query-Key 对的数量来降低计算复杂度。这类方法称为稀疏注意力(SparseAttention)机制。可以将稀疏化方法进一步分成基于位置的和基于内容信息的两类。

1. 基于位置的稀疏注意力机制

基于位置的稀疏注意力机制的基本类型如下图,主要包含如下五种类型:全局注意力(Global Attention)、带状注意力(Band Attention)、膨胀注意力(Dilated Attention)、随机注意力(Random Attention)、局部块注意力(Block Local Attention)。

这些注意力机制的区别主要在于它们如何选择序列中的元素来计算注意力权重,这直接影响计算复杂度、处理长距离依赖的能力以及对不同类型任务的适用性。每种注意力机制的关键区别和特点:

  1. 全局注意力(Global Attention):

    • 关键特点:在计算每个位置的注意力时,考虑序列中的所有其他位置。
    • 优点:能够捕获全局依赖性,理论上可以处理任意距离的关系。
    • 缺点:计算复杂度高,随序列长度的平方增长,不适合处理长序列。
  2. 带状注意力(Band Attention):

    • 关键特点:仅在每个位置的一个固定宽度的带内计算注意力权重,通常集中在序列的对角线附近。
    • 优点:减少了计算量,适合捕获局部依赖性。
    • 缺点:可能忽略重要的长距离依赖。
  3. 膨胀注意力(Dilated Attention):

    • 关键特点:通过引入膨胀因子来间隔地选择序列中的元素进行注意力计算,从而覆盖更广的范围。和CNN中的Dilated Conv类似,通过增加空隙以获取更大的感受野
    • 优点:在降低计算复杂度的同时,能够捕获更远的依赖性。
    • 缺点:可能不如全局注意力在捕捉所有长距离依赖上有效。
  4. 随机注意力(Random Attention):

    • 关键特点:随机选择序列中的位置来计算注意力权重。即通过随机采样,提升非局部的交互。
    • 优点:显著降低计算需求,引入随机性可能帮助模型探索更多的依赖关系。
    • 缺点:随机性可能导致忽略一些关键的依赖关系。
  5. 局部块注意力(Block Local Attention):

    • 关键特点:将序列分割成多个块,在这些局部块内计算注意力权重。使用多个不重叠的块Block来限制信息交互。
    • 优点:大幅降低计算复杂度,适合处理长序列。
    • 缺点:如果不允许跨块计算,则可能忽略块间的依赖关系。

总结来说,这些注意力机制通过不同的策略平衡计算复杂度和模型的捕获依赖能力。选择哪种注意力机制取决于特定任务的需求,例如处理长序列数据时可能更倾向于使用带状、膨胀、随机或局部块注意力机制,而在不那么受限于计算资源的情况下,全局注意力可能是最好的选择,因为它能够捕获全局依赖性。

在这里插入图片描述
下面给出带状注意力的栗子:

# query-shape: [bs, seq_len, emb_dim]
def band_attention(query, key, value, band_width):
    """
    Args:
        query, key, value: standard attention inputs
        band_width: The width of the band around the diagonal to compute attention.
    Returns:
        Tensor: The output of the attention mechanism.
    """
    batch_size, seq_len, d_k = query.size()
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)

    # Create a mask to zero out attention scores outside the band
    idxs = torch.arange(seq_len).unsqueeze(0).to(query.device)
    mask = (idxs - idxs.transpose(0, 1)).abs().ge(band_width).to(scores.dtype)
    scores.masked_fill_(mask, float('-inf'))

    attention = F.softmax(scores, dim=-1)
    output = torch.matmul(attention, value)
    return output

# 测试的case
def band_attention_test():
    import torch
    # 假设输入数据的维度
    batch_size = 2
    seq_length = 10
    embed_size = 128
    heads = 8
    # 生成随机数据作为输入
    values = torch.rand(batch_size, seq_length, embed_size)
    keys = torch.rand(batch_size, seq_length, embed_size)
    queries = torch.rand(batch_size, seq_length, embed_size)
    # 定义带宽
    band_width = 3
    # 使用相同的随机数据输入
    band_attention_output = band_attention(queries, keys, values, band_width)
    # Band Attention Output Shape: torch.Size([2, 10, 128])
    print("Band Attention Output Shape:", band_attention_output.shape)

可以看到上面的mask矩阵确实是带状的:
在这里插入图片描述

现有的稀疏注意力机制,通常是基于上述五种基于位置的稀疏注意力机制的复合模式,下图给出了一些典型的稀疏注意力模型:

  • star-transformer:使用带状注意力和全局注意力的组合,只包括一个全局注意力节点和宽度为3的带状注意力,其中任意两个非相邻节点通过一个共享的全局注意力连接,而相邻节点则直接相连。
  • longformer:将上层中的一些带状注意力头部替换为具有扩张窗口的注意力,在增加感受野同时不增加计算量
  • ETC(Extended Transformer Construction):利用带状注意力和外部全局节点注意力(External Global-node Attention)的组合。ETC 稀疏注意力还包括一种掩码机制来处理结构化输入,并采用对比预测编码(Contrastive Predictive Coding,CPC)进行预训练。
  • BigBird:使用带状和全局注意力,还使用额外的随机注意力来近似全连接注意力,此外还揭示了稀疏编码器和稀疏解码器的使用可以模拟任何图灵机

在这里插入图片描述

2. 基于内容的稀疏注意力机制

基于内容的稀疏注意力机制根据输入数据创建稀疏注意力,其中一种很简单的方法是选择和给定查询 (Query) 有很高相似度的键 (Key)。

(1)Routing Transformer:使用聚类

(1)Routing Transformer采用K-means 聚类方法,针对Query和Key进行聚类,类中心向量集合为 { μ i } i = 1 k \left\{\boldsymbol{\mu}_i\right\}_{i=1}^k {μi}i=1k ,其中k 是类中心的个数。每个Query 只与其处在相同簇 (Cluster) 下的Key 进行交互。中心向量采用滑动平均的方法进行更新:
μ ~ ← μ ~ + ( 1 − λ ) ( ∑ i : μ ( q i ) = μ q i + ∑ j : μ ( k j ) = μ k j ) c μ ← λ c μ + ( 1 − λ ) ∣ μ ∣ μ ← μ ~ c μ \begin{gathered} \widetilde{\boldsymbol{\mu}} \leftarrow \tilde{\boldsymbol{\mu}}+(1-\lambda)\left(\sum_{i: \mu\left(\boldsymbol{q}_i\right)=\mu} \boldsymbol{q}_i+\sum_{j: \mu\left(\boldsymbol{k}_j\right)=\mu} \boldsymbol{k}_j\right) \\ c_\mu \leftarrow \lambda c_\mu+(1-\lambda)|\mu| \\ \mu \leftarrow \frac{\widetilde{\boldsymbol{\mu}}}{c_\mu} \end{gathered} μ μ~+(1λ) i:μ(qi)=μqi+j:μ(kj)=μkj cμλcμ+(1λ)μμcμμ

(2)Reformer:使用LSH哈希

(2)Reformer 则采用局部敏感哈希 (Local-Sensitive Hashing,LSH) 的方法为每个Query 选择Key-Value 对。其主要思想是使用LSH 函数对Query 和Key 进行哈希计算,将它们划分到多个桶内,以提升在同一个桶内的Query 和Key 参与交互的概率。假设 b b b 是桶的个数,给定一个大小为 [ D k , b / 2 ] [D k , b / 2] [Dkb/2] 的随机矩阵 R R R , LSH 函数的定义为:
h ( x ) = arg ⁡ max ⁡ ( [ x R ; − x R ] ) h(\boldsymbol{x})=\arg \max ([\boldsymbol{x} R ;-\boldsymbol{x} R]) h(x)=argmax([xR;xR])

如果 h q i = h k j h \boldsymbol{q}_i=h \boldsymbol{k}_j \quad hqi=hkj 时, q i \boldsymbol{q}_i qi 才可以与相应的Key-Value对进行交互。

二、计算加速

1. GPU硬件基础知识

NVIDIA GPU中的内存(显存)按照它们物理上是在GPU芯片内部还是板卡RAM存储芯片上,决定了它们的速度、大小以及访问限制。GPU显存分为:

  • 全局内存(Global memory)
  • 本地内存(Local memory)
  • 共享内存(Shared memory,SRAM)
  • 寄存器内存(Register memory)
  • 常量内存(Constant memory)
  • 纹理内存(Texture memory)

在这里插入图片描述

全局内存和本地内存使用的高带宽显存(High Bandwidth Memory,HBM)位于板卡RAM存储芯片上,该部分内存容量很大。全局内存是所有线程都可以访问,而本地内存则只能当前线程访问。NVIDIA H100中全局内存有80GB空间,其访问速度虽然可以达到3.35TB/s,但是如果全部线程同时访问全局内存时,其平均带宽仍然很低

共享内存和寄存器位于GPU芯片上,因此容量很小,并且共享内存只有在同一个GPU线程块(Thread Block)内的线程才可以共享访问,而寄存器仅限于同一个线程内部才能访问。NVIDIA H100中每个GPU线程块在流式多处理器(Stream Multi-processor,SM)可以使用的共享存储容量仅有228KB,但是其速度非常快,远高于全局内存的访问速度。

根据自注意力机制的原理,在GPU中进行计算时,传统的方法还需要引入两个中间矩阵 S 和 P 并存储到全局内存中。具体计算过程如下:
S = Q × K , P = Softmax ⁡ ( S ) , O = P × V \boldsymbol{S}=\boldsymbol{Q} \times \boldsymbol{K}, \boldsymbol{P}=\operatorname{Softmax}(\boldsymbol{S}), \boldsymbol{O}=\boldsymbol{P} \times \boldsymbol{V} S=Q×K,P=Softmax(S),O=P×V

按照上述计算过程,需要:

  • 首先从全局内存中读取矩阵 Q Q Q K K K ,并将计算好的矩阵 S S S再写入全局内存
  • 之后再从全局内存中获取矩阵 S S S ,计算Softmax得到矩阵 P P P 再写入全局内存
  • 之后读取矩阵 P P P 和矩阵 V V V ,计算得到矩阵 O O O

这样的过程会极大占用显存的带宽。在自注意力机制中,计算速度比内存速度快得多 ,因此计算效率越来越多地受到全局内存访问的瓶颈。

2. flashattention

FlashAttention就是通过利用GPU硬件中的特殊设计,针对全局内存和共享存储的I/O速度的不同,尽可能地避免HBM中读取或写入注意力矩阵。

FlashAttention目标是尽可能高效地使用SRAM来加快计算速度,避免从全局内存中读取和写入注意力矩阵。达成该目标需要能做到在不访问整个输入的情况下计算Softmax函数,并且后向传播中不能存储中间注意力矩阵

在这里插入图片描述

FlashAttention 就提出了不使用中间注意力矩阵,通过存储归一化因子来减少全局内存消耗的方法。

FlashAttention 算法并没有将S、P 整体写入全局内存,而是通过分块写入,存储前向传递的Softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从全局内存中读取中间注意力矩阵的标准方法更快。

虽然大幅减少了全局内存的访问量,重新计算也导致FLOP(FLOPS指标,Floating Point Operations per Second 指每秒浮点运算次数) 增加,但其运行的速度更快且使用的内存更少。

在这里插入图片描述

3. 多查询注意力MQA

多查询注意力(Multi Query Attention)是多头注意力的一种变体。其特点是,在多查询注意力中不同的注意力头共享一个键和值的集合,每个头只单独保留了一份查询参数,因此键和值的矩阵仅有一份,这大幅减少了显存占用,使其更高效。
在这里插入图片描述

由于多查询注意力改变了注意力机制的结构,因此模型通常需要从训练开始就支持多查询注意力。文献研究结果表明,可以通过对已经训练好的模型进行微调来添加多查询注意力支持,仅需要约5% 的原始训练数据量就可以达到不错的效果。

包括Falcon[64]、SantaCoder[65]、StarCoder[66] 在内的很多模型都采用了多查询注意力机制。

(1)MHA和MQA的区别

MHA 和 MQA 之间的区别主要在于建立 Wqkv Layer 上(如下代码)。在MQA中,除了query向量还保存8个头,key和value向量都只剩下1个【公共头】,即前面说的所有head之间共享一份key和value参数。

# Multi Head Attention
self.Wqkv = nn.Linear(                        # 【关键】Multi-Head Attention 的创建方法
    self.d_model, 
    3 * self.d_model,                         # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
    device=device
)
query, key, value = qkv.chunk(                # 【关键】每个 tensor 都是 (1, 512, 768)
    3, 
    dim=2
)
  
# Multi Query Attention
self.Wqkv = nn.Linear(                                # 【关键】Multi-Query Attention 的创建方法
    d_model,
    d_model + 2 * self.head_dim,                      # 只创建 query 的 head 向量,所以只有 1 个 d_model
    device=device,                                    # 而 key 和 value 不再具备单独的头向量
)
query, key, value = qkv.split(                        # query -> (1, 512, 768)
    [self.d_model, self.head_dim, self.head_dim],     # key   -> (1, 512, 96)
    dim=2                                             # value -> (1, 512, 96)
)

(2)MHA和MQA的具体代码

其中MultiheadAttentionMultiQueryAttention类完整的代码如下。

class MultiheadAttention(nn.Module):
 
     def __init__(
             self,
             d_model: int,
             n_heads: int,
             device: str
        ):
         """
        Multi Head init func.
 
        Args:
            d_model (int): hidden state size, e.g. 768
            n_heads (int): 设定的注意力头数, e.g. 8
            device (str): _description_
        """
         super().__init__()
 
         self.d_model = d_model
         self.n_heads = n_heads
     
         self.Wqkv = nn.Linear(                       # Multi-Head Attention 的创建方法
             self.d_model,
             3 * self.d_model,                        # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
             device=device
        )                                            # (d_model, 3 * d_model)
         self.attn_fn = scaled_multihead_dot_product_attention
         self.out_proj = nn.Linear(
             self.d_model,
             self.d_model,
             device=device
        )
 
     def forward(
         self,
         x
    ):
         """
        forward func.
 
        Args:
            x (tensor): (batch, hidden_state, d_model) e.g. -> (1, 768, 512)
 
        Returns:
            _type_: _description_
        """
         qkv = self.Wqkv(x)                            # (1, 768, 3 * 768)
 
         query, key, value = qkv.chunk(                # 每个 tensor 都是 (1, 512, 768)
             3,
             dim=2
        )    
 
         context, attn_weights, past_key_value = self.attn_fn(
             query,
             key,
             value,
             self.n_heads
        )                                             # (1, 512, 768)
 
         return self.out_proj(context), attn_weights, past_key_value
 
 
 class MultiQueryAttention(nn.Module):
     """Multi-Query self attention.
 
    Using torch or triton attention implemetation enables user to also use
    additive bias.
    """
 
     def __init__(
         self,
         d_model: int,
         n_heads: int,
         device: Optional[str] = None,
    ):
         super().__init__()
 
         self.d_model = d_model
         self.n_heads = n_heads
         self.head_dim = d_model // n_heads
 
         self.Wqkv = nn.Linear(                           # Multi-Query Attention 的创建方法
             d_model,
             d_model + 2 * self.head_dim,                 # 只创建 query 的 head 向量,所以只有 1 个 d_model
             device=device,                               # 而 key 和 value 则只共享各自的一个 head_dim 的向量
        )
 
         self.attn_fn = scaled_multihead_dot_product_attention
         self.out_proj = nn.Linear(
             self.d_model,
             self.d_model,
             device=device
        )
         self.out_proj._is_residual = True  # type: ignore
 
     def forward(
         self,
         x,
    ):
         qkv = self.Wqkv(x)                                           # (1, 512, 960)
 
         query, key, value = qkv.split(                               # query -> (1, 512, 768)
            [self.d_model, self.head_dim, self.head_dim],             # key   -> (1, 512, 96)
             dim=2                                                    # value -> (1, 512, 96)
        )
 
         context, attn_weights, past_key_value = self.attn_fn(
             query,
             key,
             value,
             self.n_heads,
             multiquery=True,
        )
 
         return self.out_proj(context), attn_weights, past_key_value

(1)初始化函数 __init__

  • __init__(self, d_model: int, n_heads: int, device: Optional[str] = None): 这是类的初始化函数,用于创建类的实例时初始化其属性。它接受三个参数:模型的维度 d_model、注意力头的数量 n_heads,以及设备 device(可选),用于指定模块运行的硬件(CPU或GPU)。
  • self.d_model = d_modelself.n_heads = n_heads: 这两行代码将传入的模型维度和头的数量保存为类的属性。
  • self.head_dim = d_model // n_heads: 计算每个头的维度,即将模型维度均分到每个头上。
  • self.Wgkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device): 创建一个线性层 Wgkv,用于生成查询(Q)、键(K)和值(V)。这个线性层的输出维度是 d_model + 2 * self.head_dim,意味着查询的维度保持为 d_model,而键和值的维度为 self.head_dim。这种设计减少了模型参数,因为它没有为键和值分别创建额外的线性变换。
  • self.attn_fn = scaled_multihead_dot_product_attentionself.out_proj = nn.Linear(self.d_model, self.d_model, device=device): 定义了一个注意力函数 attn_fn 和一个输出投影层 out_projattn_fn 负责计算多头点积注意力,而 out_proj 用于将注意力机制的输出转换回原始输入的维度。

(2)前向传播函数 forward

  • def forward(self, X): 定义了前向传播函数,它接收一个输入张量 X
  • gkv = self.Wgkv(X): 首先,输入通过 Wgkv 线性层,产生了合并的查询、键、值矩阵。
  • query, key, value = gkv.split([self.d_model, self.head_dim, self.head_dim], dim=2): 然后,将 gkv 拆分为查询、键和值三部分。注意拆分的维度与 Wgkv 层的输出设计相匹配。
  • context, attn_weights, past_key_value = self.attn_fn(query, key, value, self.n_heads, multiquery=True): 使用定义的注意力函数计算注意力,multiquery=True 参数指示使用多查询注意力机制。
  • return self.out_proj(context), attn_weights, past_key_value: 最后,将注意力的输出通过 out_proj 投影层,然后将结果、注意力权重和过去的键值对返回。

(3)使用矩阵乘法matmul广播实现参数共享

其中注意上面的scaled_multihead_dot_product_attention函数就是实现刚才说的一份key和value参数让多个头使用,使用矩阵乘法matmul进行广播,实现参数共享。

def scaled_multihead_dot_product_attention(
        query,
        key,
        value,
        n_heads,
        multiquery=False,
    ):
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)         # (1, 512, 768) -> (1, 8, 512, 96)
    kv_n_heads = 1 if multiquery else n_heads
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)        # (1, 512, 768) -> (1, 8, 96, 512) if not multiquery 
                                                                    # (1, 512, 96) -> (1, 1, 96, 512)  if multiquery
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)      # (1, 512, 768) -> (1, 8, 512, 96) if not multiquery 
                                                                    # (1, 512, 96) -> (1, 1, 512, 96)  if multiquery
    
    attn_weight = q.matmul(k) * softmax_scale                       # (1, 8, 512, 512)
    attn_weight = torch.softmax(attn_weight, dim=-1)                # (1, 8, 512, 512)
 
    out = attn_weight.matmul(v)                                     # (1, 8, 512, 512) * (1, 1, 512, 96) = (1, 8, 512, 96)
    out = rearrange(out, 'b h s d -> b s (h d)')                    # (1, 512, 768)
 
    return out, attn_weight, past_key_value

(4)tgi框架中的MQA

具体还可以参考tgi框架中的MQA代码:

class MultiQueryAttention(nn.Module):
    """Multi-Query self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        attn_impl = config.attn_config["attn_impl"]
        self.attn_impl = config.attn_config["attn_impl"]
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.d_model = config.d_model
        d_model = config.d_model
        self.n_heads = config.n_heads
        self.softmax_scale = config.attn_config["softmax_scale"]
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.head_dim)
        self.attn_dropout_p = config.attn_config["attn_pdrop"]
        # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
        self.Wqkv = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
        )
        fuse_splits = (d_model, d_model + self.head_dim)
        if self.qk_ln:
            raise NotImplementedError("qk_ln not supported")
        if self.attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == "triton":
            self.attn_fn = triton_flash_attn_fn
            if verbose:
                warnings.warn(
                    "While `attn_impl: triton` can be faster than `attn_impl: flash` "
                    + "it uses more memory. When training larger models this can trigger "
                    + "alloc retries which hurts performance. If encountered, we recommend "
                    + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
                )
        elif self.attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available() and verbose:
                warnings.warn(
                    "Using `attn_impl: torch`. If your model does not use `alibi` or "
                    + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
                    + "we recommend using `attn_impl: triton`."
                )
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=not config.no_bias,
        )
        # self.out_proj._is_residual = True

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        (query, key, value) = qkv.split(
            [self.d_model, self.head_dim, self.head_dim], dim=2
        )
        key_padding_mask = attention_mask
        if self.qk_ln:
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)
        (context, attn_weights, past_key_value) = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
            multiquery=True,
        )
        return (self.out_proj(context), attn_weights, past_key_value)

Reference

[1] https://github.com/huggingface/text-generation-inference
[2] LLM 加速技巧:Muti Query Attention
[3] 训练模型算力的单位:FLOPs、FLOPS、Macs 与 估算模型(FC, CNN, LSTM, Transformers&&LLM)的FLOPs
[4] FlashAttention 的速度优化原理是怎样的?
[5] FlashAttention图解(如何加速Attention)
[6] flashattention论文:https://arxiv.org/pdf/2205.14135.pdf
[7] 全局注意力机制(global attention)详解与代码实现
[8] 每天学习一点点—大模型知识学习

<think>好的,我现在需要帮用户全面梳理大语言模型(LLM)领域的核心知识点,形成全维度的知识体系。用户之前已经提供过一个结构化的总结,现在需要根据新的引用内容和之前的回答进行补充和完善。首先,我得仔细看看用户提供的引用内容,看看哪些可以整合进去。 用户提到的引用[1]是关于天文大语言模型StarGLM,这个项目基于ChatGLM,专注于缓解天文领域的幻觉现象,并且为多模态任务和观测Agent做准备。这应该添加到应用挑战或实践案例部分,特别是在幻觉问题的解决方案中,可以提到StarGLM作为具体案例。同时,Agent部分也可以引用这个司天大脑的例子。 引用[2]讨论了训练LLM的三个阶段:预训练、监督微调(SFT)和人类对齐训练。这部分需要整合到训练方法体系中,可能需要在无监督预训练和参数高效微调部分补充相关内容,尤其是SFT的作用和应用场景,比如用户提到的场景存在特殊知识需要知识灌注,可能需要继续训练+全量训练,或者风格定制需要对齐训练,这些可以放在训练方法的不同阶段说明。 引用[3]提到了不同训练方法的适用场景,比如知识灌注、回复风格定制、模型能力不足时的微调等。这部分应该补充到参数高效微调技术或训练方法体系的章节中,帮助理解何时选择哪种微调策略。比如全量训练、轻量训练的应用场景,可以结合表格或说明来增强可读性。 接下来,我需要回顾之前的回答,确保新的内容与之前的结构保持一致,并且覆盖用户提到的所有知识点:提示词工程、注意力机制位置编码、参数高效微调技术、幻觉问题、对齐难度、无监督学习、token、卷积神经网络、Agent、多模态、最新技术进展、实践案例。 需要注意用户提到的卷积神经网络,虽然LLM主要基于Transformer,但可能在某些预处理或特定任务中用到CNN,比如多模态中的图像处理部分,例如Flamingo或BLIP-2可能结合CNN处理视觉特征,需要补充这一点。 在整合引用内容时,要自然地添加引用标识,比如StarGLM的例子引用[^1],训练阶段的SFT引用[^2],适用场景引用[^3]。同时,确保每个模块(核心技术、训练方法、应用挑战、前沿技术)都有最新的案例和技术进展,比如MoE架构、轻量化部署、Agent系统等。 需要检查是否有重复的内容,比如幻觉问题在之前的回答中已经提到过解决方案,现在加入StarGLM作为案例。训练方法中的参数高效微调部分需要结合引用[3]中的适用场景,说明不同微调技术的使用情况。 最后,确保回答结构清晰,分点明确,每个知识点都有核心技术、训练方法、挑战与解决方案、前沿融合等方面的覆盖,并且引用适当,符合用户要求的格式。同时,补充用户可能遗漏的知识点,如token的处理、CNN在LLM中的角色等,以形成全面的知识体系。</think>以下是大语言模型(LLM)全维度知识体系总结,整合最新技术进展与实践案例: --- ### 一、**核心技术模块** 1. **提示词工程** - 核心方法: - 零样本提示(Zero-Shot):直接执行未训练任务 - 小样本提示(Few-Shot):通过示例引导模型行为 - 思维链(CoT):分步推理提升复杂任务准确性[^1] - 创新方向: - **程序辅助语言模型(PAL)**:将自然语言指令映射为可执行代码逻辑 - **结构化模板**:Claude 3通过代码模板约束生成格式 2. **注意力机制** - 数学表达: $$ \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$ - 优化技术: - **FlashAttention-2**:GPU显存利用率提升3倍 - **滑动窗口注意力**(Mistral-7B):限制上下文窗口降低计算复杂度 - 多模态扩展:Flamingo模型通过**交叉注意力**对齐文本与图像特征[^2] 3. **位置编码** - 绝对编码:Transformer原始正余弦函数 $$ PE_{(pos,2i)}=\sin(pos/10000^{2i/d}) $$ - 相对编码: - **RoPE**(旋转位置编码):LLaMA2支持4096 tokens长文本 - **ALiBi**:通过线性偏置实现动态上下文扩展 4. **卷积神经网络融合** - 视觉处理:BLIP-2使用ViT-CNN混合架构提取图像特征 - 时序建模:Video-LLaMA引入时间卷积处理视频帧序列[^2] --- ### 二、**训练方法体系** 1. **参数高效微调技术** | 技术 | 参数量占比 | 典型场景 | |-------------|------------|---------------------------| | LoRA | <5% | Stable Diffusion风格迁移 | | Adapter | 3-10% | LLaMA-Adapter多任务学习 | | Prefix Tuning| 0.1% | 对话系统领域适配 | - **知识灌注场景**:天文模型StarGLM通过全量训练注入领域知识[^1] 2. **无监督预训练** - 数据策略: - **PaLM数据集去重**:1.56T tokens清洗后数据 - **Chinchilla定律**:模型规模与数据量最优比例(20参数/1 token) - 目标函数: - 掩码语言建模(MLM) - 自回归预测(AR) - 对比学习(SimCSE) 3. **对齐技术** - 三阶段流程: 1. 监督微调(SFT):精选指令数据优化行为[^2] 2. 奖励建模:构建人类偏好数据集 3. 强化学习(RLHF):PPO算法迭代优化 - **DPO优化**:直接偏好优化替代RLHF,训练效率提升6倍 --- ### 三、**核心挑战与解决方案** 1. **幻觉问题** - 根源分析: - 训练数据噪声(如天文术语错误[^1]) - 推理过度依赖统计先验 - 解决方案: - **检索增强(RAG)**:结合外部知识库 - **置信度校准**:DELLLA模型引入事实性评分机制 - **多模态对齐**:BLIP-2通过Q-Former约束视觉-语言关联[^2] 2. **长文本建模** - 位置编码:RoPE扩展至32k tokens(Code Llama) - 注意力优化:Ring Attention实现百万级上下文窗口 3. **多模态融合** - 视觉语言模型: - LLaVA:CLIP编码器+MLP适配器 - MiniGPT-4:两阶段训练缓解图文错位 - 视频理解: - Video-LLaMA:时间注意力模块捕捉时序特征 --- ### 四、**前沿技术融合** 1. **Agent系统** - 工具调用: - Toolformer:集成计算器、搜索引擎API - **司天大脑**:天文观测Agent处理多模态数据 - 自主决策:AutoGPT实现任务分解与递归执行 2. **MoE架构** - Mixtral-8x7B:稀疏激活机制,推理速度提升5倍 - 路由算法:Expert Choice实现专家负载均衡 3. **轻量化部署** - 量化技术: - GPTQ:4-bit量化精度损失<1% - AWQ:激活感知量化优化 - 服务框架: - vLLM:PagedAttention技术提升24倍吞吐量 --- ### 五、**关键概念补充** 1. **Token化机制** - Byte-Pair Encoding(BPE):GPT系列采用 - SentencePiece:多语言支持(如LLaMA) - Unicode规范化:处理多语言混合文本 2. **训练阶段策略** - **知识灌注阶段**:全量训练更新所有参数(如StarGLM天文知识注入) - **能力增强阶段**:轻量微调适配特定任务[^3] ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

山顶夕景

小哥哥给我买个零食可好

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值