flash_attention modules下的block、mha代码阅读笔记

FlashAttention:快速且内存高效的准确注意力机制

在深度学习领域,注意力机制是提高模型性能的关键组件。然而,传统的注意力机制在长序列处理时会消耗大量内存和计算资源。为了解决这个问题,Tri Dao等人提出了FlashAttention,这是一种快速且内存高效的注意力机制。本文将介绍FlashAttention及其改进版FlashAttention-2的核心概念、安装方法和使用示例。

论文介绍

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

安装和特性

环境要求

  • CUDA: 11.6及以上
  • PyTorch: 1.12及以上
  • 操作系统: Linux(从v2.3.2开始有部分Windows的正面反馈,但Windows编译仍需更多测试)

我们推荐使用Nvidia的PyTorch容器,其中包含安装FlashAttention所需的所有工具。

安装步骤

  1. 确保已安装PyTorch
  2. 安装packagingpip install packaging
  3. 安装ninja并确保其正常工作:ninja --version && echo $?应返回退出码0。如果未返回0,重新安装ninja:pip uninstall -y ninja && pip install ninja
使用pip安装
pip install flash-attn --no-build-isolation
从源码编译
python setup.py install
控制并行编译任务数(适用于RAM少于96GB且有多个CPU核心的机器)
MAX_JOBS=4 pip install flash-attn --no-build-isolation

使用示例

FlashAttention主要实现了缩放点积注意力(softmax(Q @ K^T * softmax_scale) @ V)。以下是使用FlashAttention的核心函数:

from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

# 当Q, K, V已堆叠为一个张量时,使用flash_attn_qkvpacked_func
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
                                window_size=(-1, -1), alibi_slopes=None, deterministic=False)

# 直接使用Q, K, V时,使用flash_attn_func
out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                      window_size=(-1, -1), alibi_slopes=None, deterministic=False)

参数说明

  • qkv: (batch_size, seqlen, 3, nheads, headdim)格式的张量,包含Q, K, V
  • dropout_p: float,Dropout概率
  • softmax_scale: float,softmax前QK^T的缩放比例,默认为1 / sqrt(headdim)
  • causal: bool,是否应用因果注意力掩码(如用于自回归建模)
  • window_size: (left, right),如果不为(-1, -1),则实现滑动窗口局部注意力
  • alibi_slopes: (nheads,)或(batch_size, nheads),fp32。对查询i和键j的注意力分数加上一个偏置(-alibi_slope * |i - j|)
  • deterministic: bool,是否使用确定性实现的反向传播(略慢且使用更多内存)

性能表现

加速效果

FlashAttention在A100 80GB SXM5 GPU上使用FP16/BF16格式时的加速效果如下:

  • Head Dimension: 64或128
  • Hidden Dimension: 2048(即32或16个heads)
  • Sequence Length: 512, 1k, 2k, 4k, 8k, 16k
  • Batch Size: 16k / seqlen

内存节省

FlashAttention在处理较长序列时能显著节省内存。与标准注意力机制内存使用随序列长度二次增长不同,FlashAttention的内存使用线性增长。在序列长度为2K时可节省10倍内存,4K时可节省20倍内存。

完整模型代码和训练脚本

已发布了完整的GPT模型实现,并提供了其他层(如MLP、LayerNorm、交叉熵损失、旋转嵌入)的优化实现。整体上,训练速度较基线实现(如Huggingface实现)提高3-5倍,达到每A100 225 TFLOPs/sec,相当于72%的模型FLOPs利用率。

FlashAttention 更新日志

2.0:完全重写,速度提升2倍

FlashAttention在2.0版本中进行了完全重写,速度提升了两倍。本次更新引入了多个更改和改进,包括一些函数名称的更改以及在输入具有相同序列长度的情况下简化了使用方式。

函数重命名

以下函数的名称已更新,以反映其更新后的功能:

  • flash_attn_unpadded_func -> flash_attn_varlen_func
  • flash_attn_unpadded_qkvpacked_func -> flash_attn_varlen_qkvpacked_func
  • flash_attn_unpadded_kvpacked_func -> flash_attn_varlen_kvpacked_func

如果输入在同一批次中具有相同的序列长度,使用以下函数将更加简单和快速:

  • flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
  • flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)

2.1:更改causal标志的行为

如果 seqlen_q != seqlen_k 并且 causal=True,则causal掩码将对齐到注意力矩阵的右下角,而不是左上角。

例如,如果 seqlen_q = 2seqlen_k = 5,则causal掩码(1 = 保留,0 = 掩盖)如下:

v2.0版本:

1 0 0 0 0
1 1 0 0 0

v2.1版本:

1 1 1 1 0
1 1 1 1 1

如果 seqlen_q = 5seqlen_k = 2,则causal掩码如下:

v2.0版本:

1 0
1 1
1 1
1 1
1 1

v2.1版本:

0 0
0 0
0 0
1 0
1 1

如果掩码的行全为零,则输出也将为零。

2.2:针对推理进行优化

在查询序列长度非常短(例如查询序列长度=1)的情况下,针对推理(迭代解码)进行优化。这里的瓶颈是尽可能快地加载KV缓存,我们通过不同线程块分割加载,并使用一个单独的内核来合并结果。

请参阅具有更多推理功能的 flash_attn_with_kvcache 函数(执行旋转嵌入,原地更新KV缓存)。

感谢xformers团队,特别是Daniel Haziza的合作。

2.3:局部(即滑动窗口)注意力

实现滑动窗口注意力(即局部注意力)。感谢Mistral AI团队,特别是Timothée Lacroix的贡献。滑动窗口被用于Mistral 7B模型中。

2.4:ALiBi(线性偏差注意力),确定性反向传播

实现ALiBi(Press等人,2021)。感谢Kakao Brain的Sanghun Cho的贡献。

实现确定性反向传播。感谢美团的工程师们的贡献。

2.5:分页KV缓存

支持分页KV缓存(即PagedAttention)。感谢 @beginlner 的贡献。

代码目录flash_attn/modules/block.py

代码解读:Block

在这篇博客中,我们将逐段解读 Block 类的代码。该类实现了一个通用的块结构,广泛应用于Transformer等模型中。我们将详细介绍代码实现流程和所用到的理论基础。

理论基础

在Transformer架构中,最基本的构件是编码器和解码器层(block)。每个层通常包括以下部分:

  1. 多头自注意力机制:用于计算每个词对其他词的注意力权重。
  2. 前馈神经网络:对每个词的表示进行非线性变换。
  3. 残差连接和层归一化:为了稳定训练,添加了残差连接和层归一化。

有两种常见的层结构:

  • Prenorm结构:层归一化在主要操作(注意力或前馈神经网络)之前应用。
  • Postnorm结构:层归一化在主要操作之后应用。
代码实现流程
class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):

这段代码定义了 Block 类的构造函数。以下是参数的解释:

  • dim:输入和输出的维度。
  • mixer_cls:用于计算注意力的类。
  • mlp_cls:用于前馈神经网络的类。
  • norm_cls:用于层归一化的类。
  • dropout_cls:用于Dropout的类。
  • prenorm:是否使用Prenorm结构。
  • resid_dropout1resid_dropout2:残差连接的Dropout率。
  • drop_path1drop_path2:用于Stochastic Depth的参数。
  • fused_dropout_add_ln:是否融合Dropout、Add和LayerNorm操作。
  • return_residual:是否在每个子层返回残差。
  • residual_in_fp32:是否使用FP32精度保存残差。
  • sequence_parallel:是否并行处理序列。
  • mark_shared_params:是否标记共享参数。
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

在构造函数中,首先初始化了各个参数。根据 prenormfused_dropout_add_ln 等标志设置了一些断言和默认值。如果没有提供 mixer_clsmlp_cls,则使用默认的多头注意力机制(MHA)和前馈神经网络(Mlp)。

接下来,初始化了 mixerdropoutStochasticDepthnorm 层。

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True

        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

这段代码处理了 fused_dropout_add_lnsequence_parallelmark_shared_params 的情况。如果启用了 fused_dropout_add_ln,则确保安装了 Triton,并且 norm1dropout1 是有效类型。如果启用了 sequence_parallel,则将 norm1norm2 的参数标记为需要序列并行。如果启用了 mark_shared_params,则将这些参数标记为共享参数。

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):

定义了两个方法:

  • allocate_inference_cache:为推理阶段分配缓存。
  • forward:前向传播函数,处理输入的 hidden_statesresidual
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm)
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual

如果使用 prenorm,则首先处理 dropoutresidual,然后应用 norm1。根据 fused_dropout_add_ln 的设置,选择是否融合这些操作。之后调用 mixer 层(通常是多头注意力机制)。最后,如果 mlp 不是 Identity,则进行类似的操作处理 mlp 层。

        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self

.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                        )
                    )
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                            )
                        )
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm)
                    )
            return hidden_states

对于 postnorm 结构,处理流程类似,但 layer norm 应用于主要操作(注意力和前馈神经网络)之后。这里的 residual 在一开始设为 None。处理 mixer 层,之后处理 mlp 层。

通过这段代码,Block 类可以灵活地支持 prenormpostnorm 结构,以及各种Dropout、残差连接和层归一化的组合。这使得它在实现不同类型的Transformer架构时非常高效和通用。

代码解读博客:ParallelBlock

在这篇博客中,我们将逐段解读 ParallelBlock 类的代码。该类实现了并行的注意力(mixer)和MLP块,类似于GPT-J、GPT-NeoX和PaLM模型的结构。

理论基础

ParallelBlock 类采用了一种略有不同于常规Transformer块的结构。传统的Transformer块通常遵循以下结构:Layer Norm (LN) -> Multi-Head Attention (MHA) / MLP -> Dropout -> Add。而 ParallelBlock 中的结构为:Dropout -> Add -> LN -> MHA / MLP。这种结构的优势在于可以融合dropout、add和LayerNorm操作,从而提升性能。

代码实现流程
class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True

        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

在构造函数中,首先初始化了各个参数。根据 tied_normfused_dropout_add_ln 等标志设置了一些断言和默认值。如果没有提供 mixer_clsmlp_cls,则使用默认的多头注意力机制(MHA)和前馈神经网络(Mlp)。初始化了 mixerdropoutnormmlp 层,并根据条件设置了层归一化参数的并行序列和共享标志。

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual.
        """
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (
                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if not self.tied_norm
                else hidden_states1
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
            )
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm)
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                hidden_states2, = rest

        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual

forward 方法中,根据 fused_dropout_add_ln 的设置,选择是否融合dropout、add和LayerNorm操作。根据 hidden_states2 是否为 None,决定是否添加dropout到 residual 中。然后应用 norm1norm2,并处理残差。调用 mixermlp 层,最后返回 hidden_states1hidden_states2residual

通过这段代码,ParallelBlock 类实现了并行的注意力和MLP块结构,为模型的性能优化提供了一种有效的方法。

代码目录 flash_attn/modules/mha.py

代码解读:FlashSelfAttention

在这篇博客中,我们将逐段解读 FlashSelfAttention 类的代码。该类实现了一个带有 Softmax 的多头自注意力机制。我们将详细介绍代码实现流程和所用到的理论基础。

理论基础

自注意力机制是Transformer架构的核心。其主要原理是通过查询(query)、键(key)和值(value)来计算输入序列中每个元素的重要性权重,并根据这些权重对值进行加权求和。具体来说,自注意力机制通过以下步骤实现:

  1. 线性变换:将输入向量通过线性层分别映射到查询、键和值向量。
  2. 计算注意力权重:使用查询和键向量的点积计算注意力权重,通常会进行缩放并通过Softmax函数归一化。
  3. 加权求和:将归一化后的注意力权重与值向量相乘,得到输出向量。

多头注意力机制通过多个独立的注意力头来捕捉输入序列中的不同特征,并将这些特征拼接后再通过线性变换进行融合。

代码实现流程
class FlashSelfAttention(nn.Module):
    """实现带有Softmax的缩放点积注意力。
    参数
    ---------
        softmax_scale: 用于Softmax注意力的温度参数。
                      (默认值:1/sqrt(d_keys),其中d_keys在运行时计算)
        attention_dropout: 对注意力应用的Dropout率
                           (默认值:0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention未安装"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention未安装"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

这段代码定义了 FlashSelfAttention 类的构造函数。以下是参数的解释:

  • causal:是否为因果注意力,即是否考虑序列的时间顺序。
  • softmax_scale:Softmax的温度参数,用于缩放点积结果。
  • attention_dropout:注意力机制中的Dropout率。
  • window_size:局部窗口大小。
  • alibi_slopes:用于调整注意力偏置的斜率。
  • deterministic:是否使用确定性操作。

构造函数中首先通过断言确保所需的FlashAttention函数已安装,然后初始化各个参数,并通过 register_buffer 注册不需要梯度的参数。

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """实现多头Softmax注意力。
        参数
        ---------
            qkv: 包含查询、键和值的张量。
                如果cu_seqlens为None且max_seqlen为None,则qkv形状为(B, S, 3, H, D)。
                如果cu_seqlens不为None且max_seqlen不为None,则qkv形状为(total, 3, H, D),
                其中total是批次中序列长度的总和。
            causal: 如果传递,将覆盖self.causal
            cu_seqlens: (batch_size + 1,) 形状的张量,类型为torch.int32。批次中序列的累计长度,用于索引到qkv中。
            max_seqlen: int。批次中最大序列长度。
        返回:
        --------
            out: 如果cu_seqlens不为None且max_seqlen不为None,则形状为(total, H, D),
                否则为(B, S, H, D)。
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )

这段代码实现了 forward 方法,即前向传播过程。以下是参数的解释:

  • qkv:包含查询、键和值的张量。
  • causal:如果传递,将覆盖 self.causal
  • cu_seqlens:批次中序列的累计长度,用于索引到 qkv 中。
  • max_seqlen:批次中最大序列长度。

在前向传播过程中,首先检查 qkv 的数据类型和设备类型。然后根据是否有 cu_seqlens 来确定是否使用未填充的序列。如果使用未填充的序列,则通过 flash_attn_varlen_qkvpacked_func 函数计算注意力;否则通过 flash_attn_qkvpacked_func 函数计算注意力。

代码小结

FlashSelfAttention 类实现了一个带有Softmax的多头自注意力机制。其主要步骤包括:

  1. 初始化参数。
  2. 在前向传播过程中,根据输入张量的形状和参数选择适当的函数计算注意力。

通过以上代码,我们可以高效地计算自注意力机制,并在需要时应用Dropout和因果注意力。

代码解读博客:FlashCrossAttention

在这篇博客中,我们将逐段解读 FlashCrossAttention 类的代码。该类实现了带有 Softmax 的缩放点积交叉注意力机制。我们将详细介绍代码实现流程和所用到的理论基础。

理论基础

交叉注意力机制与自注意力机制类似,但它使用不同的查询(query)、键(key)和值(value)来源于不同的序列。其主要步骤包括:

  1. 线性变换:将查询、键和值向量通过线性层进行映射。
  2. 计算注意力权重:使用查询和键向量的点积计算注意力权重,通常进行缩放并通过 Softmax 函数归一化。
  3. 加权求和:将归一化后的注意力权重与值向量相乘,得到输出向量。

交叉注意力在许多任务中具有广泛应用,如机器翻译中的编码器-解码器架构。

代码实现流程
class FlashCrossAttention(nn.Module):
    """实现带有Softmax的缩放点积注意力。
    参数
    ---------
        softmax_scale: 用于Softmax注意力的温度参数。
                      (默认值:1/sqrt(d_keys),其中d_keys在运行时计算)
        attention_dropout: 对注意力应用的Dropout率
                           (默认值:0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        alibi_slopes=None,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention未安装"
        assert flash_attn_kvpacked_func is not None, "FlashAttention未安装"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

这段代码定义了 FlashCrossAttention 类的构造函数。以下是参数的解释:

  • causal:是否为因果注意力,即是否考虑序列的时间顺序。
  • softmax_scale:Softmax的温度参数,用于缩放点积结果。
  • attention_dropout:注意力机制中的Dropout率。
  • window_size:局部窗口大小。
  • alibi_slopes:用于调整注意力偏置的斜率。
  • deterministic:是否使用确定性操作。

构造函数中首先通过断言确保所需的FlashAttention函数已安装,然后初始化各个参数,并通过 register_buffer 注册不需要梯度的参数。

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """实现多头Softmax注意力。
        参数
        ---------
            q: 包含查询的张量。形状为 (B, Sq, H, D)
            kv: 包含键和值的张量。形状为 (B, Sk, 2, H_k, D)
            causal: 如果传递,将覆盖self.causal
            cu_seqlens: (batch_size + 1,) 形状的张量,类型为torch.int32。批次中序列的累计长度,用于索引到 q 中。
            max_seqlen: int。批次中 q 的最大序列长度。
            cu_seqlens_k: (batch_size + 1,) 形状的张量,类型为torch.int32。批次中序列的累计长度,用于索引到 kv 中。
            max_seqlen_k: int。批次中 k 和 v 的最大序列长度。
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )

这段代码实现了 forward 方法,即前向传播过程。以下是参数的解释:

  • q:包含查询的张量,形状为 (B, Sq, H, D)
  • kv:包含键和值的张量,形状为 (B, Sk, 2, H_k, D)
  • causal:如果传递,将覆盖 self.causal
  • cu_seqlens:批次中序列的累计长度,用于索引到 q 中。
  • max_seqlen:批次中 q 的最大序列长度。
  • cu_seqlens_k:批次中序列的累计长度,用于索引到 kv 中。
  • max_seqlen_k:批次中 kv 的最大序列长度。

在前向传播过程中,首先检查 qkv 的数据类型和设备类型。然后根据是否有 cu_seqlens 来确定是否使用未填充的序列。如果使用未填充的序列,则通过 flash_attn_varlen_kvpacked_func 函数计算注意力;否则通过 flash_attn_kvpacked_func 函数计算注意力。

代码小结

FlashCrossAttention 类实现了一个带有Softmax的多头交叉注意力机制。其主要步骤包括:

  1. 初始化参数。
  2. 在前向传播过程中,根据输入张量的形状和参数选择适当的函数计算注意力。

通过以上代码,我们可以高效地计算交叉注意力机制,并在需要时应用Dropout和因果注意力。

代码解读博客:SelfAttention

在这篇博客中,我们将逐段解读 SelfAttention 类的代码。该类实现了带有 Softmax 的缩放点积自注意力机制。我们将详细介绍代码实现流程和所用到的理论基础。

理论基础

自注意力机制是Transformer架构的核心,它允许模型在计算每个词的表示时考虑序列中的其他所有词。其主要步骤包括:

  1. 线性变换:将查询(query)、键(key)和值(value)向量通过线性层进行映射。
  2. 计算注意力权重:使用查询和键向量的点积计算注意力权重,通常进行缩放并通过Softmax函数归一化。
  3. 加权求和:将归一化后的注意力权重与值向量相乘,得到输出向量。
代码实现流程
class SelfAttention(nn.Module):
    """实现带有Softmax的缩放点积注意力。
    参数
    ---------
        softmax_scale: 用于Softmax注意力的温度参数。
                      (默认值:1/sqrt(d_keys),其中d_keys在运行时计算)
        attention_dropout: 对注意力应用的Dropout率
                           (默认值:0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

这段代码定义了 SelfAttention 类的构造函数。以下是参数的解释:

  • causal:是否为因果注意力,即是否考虑序列的时间顺序。
  • softmax_scale:Softmax的温度参数,用于缩放点积结果。
  • attention_dropout:注意力机制中的Dropout率。

构造函数中初始化了因果注意力标志、Softmax缩放参数和Dropout层。

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """实现多头Softmax注意力。
        参数
        ---------
            qkv: 包含查询、键和值的张量。形状为 (B, S, 3, H, D)
            causal: 如果传递,将覆盖self.causal
            key_padding_mask: 布尔掩码,用于对注意力权重进行掩码处理。True表示保留,False表示掩码。形状为 (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])

在前向传播过程中,首先获取批次大小和序列长度。如果传递了 causal 参数,则覆盖 self.causal。然后将 qkv 张量沿第三个维度进行拆分,得到查询、键和值。最后,计算Softmax的缩放参数。

        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)

这行代码使用爱因斯坦求和约定(einsum)计算查询和键的点积,并进行缩放。scores 张量形状为 (B, H, T, S),表示每个查询与所有键的相似度。

        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")

如果提供了 key_padding_mask,则对注意力分数进行掩码处理。首先创建一个填充掩码,将指定位置的值设为一个非常小的数(例如-10000.0)。然后,将这个掩码应用到注意力分数上,掩盖掉不需要考虑的值。

        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)

如果是因果注意力,则构造一个上三角掩码(只保留对角线及其以下的元素),掩盖掉未来的时间步。这个掩码用于确保当前时间步只能关注自己和之前的时间步,防止信息泄漏。

        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output

计算Softmax得到注意力权重,然后应用Dropout。最后,使用注意力权重对值进行加权求和,得到最终的输出。

代码小结

SelfAttention 类实现了一个带有Softmax的多头自注意力机制。其主要步骤包括:

  1. 初始化参数。
  2. 在前向传播过程中,根据输入张量计算注意力分数。
  3. 应用键填充掩码和因果掩码。
  4. 计算Softmax得到注意力权重,并应用Dropout。
  5. 使用注意力权重对值进行加权求和,得到最终输出。

结论

FlashAttention及其改进版FlashAttention-2为注意力机制在深度学习中的应用提供了显著的速度和内存优化,使得处理长序列数据变得更加高效。希望本文对您了解和使用FlashAttention有所帮助。


如果您对FlashAttention有任何问题或建议,欢迎通过GitHub issue与我们联系。

参考链接:

  • 12
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
这段代码定义了两个自注意力机制的子类:`CrossAttention` 和 `GlobalSelfAttention`。这两个子类都继承了一个基础的注意力层 `BaseAttention`。 `BaseAttention` 类中定义了注意力层的基本结构。它包含了一个多头注意力层(`MultiHeadAttention`),一个层归一化层(`LayerNormalization`)和一个加法层(`Add`)。其中,多头注意力层用于计算注意力权重和上下文向量,层归一化层用于规范化输入向量,加法层用于将输入向量与上下文向量相加。 `CrossAttention` 类是 `BaseAttention` 的子类,在其基础上增加了一个 `call()` 方法。该方法用于执行跨注意力操作,接收两个输入张量 `x` 和 `context`,并使用多头注意力层计算 `x` 相对于 `context` 的注意力权重和上下文向量。然后,通过加法层和层归一化层将输入向量和上下文向量相加,并返回结果。 `GlobalSelfAttention` 类也是 `BaseAttention` 的子类,它实现了全局自注意力操作。在 `call()` 方法中,它接收一个输入张量 `x`,并使用多头注意力层计算 `x` 自身的注意力权重和上下文向量。然后,通过加法层和层归一化层将输入向量和上下文向量相加,并返回结果。 这段代码使用了 TensorFlow 框架的 `tf.keras.layers` 模块来定义注意力层的结构。你可以根据自己的需求进一步使用这些类来构建注意力机制的模型。请注意,这只是代码片段的一部分,可能还需要根据具体的模型和任务进行适当的修改和调整。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值