FlashAttention:快速且内存高效的准确注意力机制
在深度学习领域,注意力机制是提高模型性能的关键组件。然而,传统的注意力机制在长序列处理时会消耗大量内存和计算资源。为了解决这个问题,Tri Dao等人提出了FlashAttention,这是一种快速且内存高效的注意力机制。本文将介绍FlashAttention及其改进版FlashAttention-2的核心概念、安装方法和使用示例。
论文介绍
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- 作者: Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
- 论文链接: arxiv.org/abs/2205.14135
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- 作者: Tri Dao
- 论文链接: flash2.pdf
安装和特性
环境要求
- CUDA: 11.6及以上
- PyTorch: 1.12及以上
- 操作系统: Linux(从v2.3.2开始有部分Windows的正面反馈,但Windows编译仍需更多测试)
我们推荐使用Nvidia的PyTorch容器,其中包含安装FlashAttention所需的所有工具。
安装步骤
- 确保已安装PyTorch
- 安装
packaging
:pip install packaging
- 安装
ninja
并确保其正常工作:ninja --version && echo $?
应返回退出码0。如果未返回0,重新安装ninja:pip uninstall -y ninja && pip install ninja
使用pip安装
pip install flash-attn --no-build-isolation
从源码编译
python setup.py install
控制并行编译任务数(适用于RAM少于96GB且有多个CPU核心的机器)
MAX_JOBS=4 pip install flash-attn --no-build-isolation
使用示例
FlashAttention主要实现了缩放点积注意力(softmax(Q @ K^T * softmax_scale) @ V)。以下是使用FlashAttention的核心函数:
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
# 当Q, K, V已堆叠为一个张量时,使用flash_attn_qkvpacked_func
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
window_size=(-1, -1), alibi_slopes=None, deterministic=False)
# 直接使用Q, K, V时,使用flash_attn_func
out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
window_size=(-1, -1), alibi_slopes=None, deterministic=False)
参数说明
qkv
:(batch_size, seqlen, 3, nheads, headdim)
格式的张量,包含Q, K, Vdropout_p
: float,Dropout概率softmax_scale
: float,softmax前QK^T的缩放比例,默认为1 / sqrt(headdim)causal
: bool,是否应用因果注意力掩码(如用于自回归建模)window_size
: (left, right),如果不为(-1, -1),则实现滑动窗口局部注意力alibi_slopes
: (nheads,)或(batch_size, nheads),fp32。对查询i和键j的注意力分数加上一个偏置(-alibi_slope * |i - j|)deterministic
: bool,是否使用确定性实现的反向传播(略慢且使用更多内存)
性能表现
加速效果
FlashAttention在A100 80GB SXM5 GPU上使用FP16/BF16格式时的加速效果如下:
- Head Dimension: 64或128
- Hidden Dimension: 2048(即32或16个heads)
- Sequence Length: 512, 1k, 2k, 4k, 8k, 16k
- Batch Size: 16k / seqlen
内存节省
FlashAttention在处理较长序列时能显著节省内存。与标准注意力机制内存使用随序列长度二次增长不同,FlashAttention的内存使用线性增长。在序列长度为2K时可节省10倍内存,4K时可节省20倍内存。
完整模型代码和训练脚本
已发布了完整的GPT模型实现,并提供了其他层(如MLP、LayerNorm、交叉熵损失、旋转嵌入)的优化实现。整体上,训练速度较基线实现(如Huggingface实现)提高3-5倍,达到每A100 225 TFLOPs/sec,相当于72%的模型FLOPs利用率。
FlashAttention 更新日志
2.0:完全重写,速度提升2倍
FlashAttention在2.0版本中进行了完全重写,速度提升了两倍。本次更新引入了多个更改和改进,包括一些函数名称的更改以及在输入具有相同序列长度的情况下简化了使用方式。
函数重命名
以下函数的名称已更新,以反映其更新后的功能:
flash_attn_unpadded_func
->flash_attn_varlen_func
flash_attn_unpadded_qkvpacked_func
->flash_attn_varlen_qkvpacked_func
flash_attn_unpadded_kvpacked_func
->flash_attn_varlen_kvpacked_func
如果输入在同一批次中具有相同的序列长度,使用以下函数将更加简单和快速:
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
2.1:更改causal标志的行为
如果 seqlen_q != seqlen_k
并且 causal=True
,则causal掩码将对齐到注意力矩阵的右下角,而不是左上角。
例如,如果 seqlen_q = 2
且 seqlen_k = 5
,则causal掩码(1 = 保留,0 = 掩盖)如下:
v2.0版本:
1 0 0 0 0
1 1 0 0 0
v2.1版本:
1 1 1 1 0
1 1 1 1 1
如果 seqlen_q = 5
且 seqlen_k = 2
,则causal掩码如下:
v2.0版本:
1 0
1 1
1 1
1 1
1 1
v2.1版本:
0 0
0 0
0 0
1 0
1 1
如果掩码的行全为零,则输出也将为零。
2.2:针对推理进行优化
在查询序列长度非常短(例如查询序列长度=1)的情况下,针对推理(迭代解码)进行优化。这里的瓶颈是尽可能快地加载KV缓存,我们通过不同线程块分割加载,并使用一个单独的内核来合并结果。
请参阅具有更多推理功能的 flash_attn_with_kvcache
函数(执行旋转嵌入,原地更新KV缓存)。
感谢xformers团队,特别是Daniel Haziza的合作。
2.3:局部(即滑动窗口)注意力
实现滑动窗口注意力(即局部注意力)。感谢Mistral AI团队,特别是Timothée Lacroix的贡献。滑动窗口被用于Mistral 7B模型中。
2.4:ALiBi(线性偏差注意力),确定性反向传播
实现ALiBi(Press等人,2021)。感谢Kakao Brain的Sanghun Cho的贡献。
实现确定性反向传播。感谢美团的工程师们的贡献。
2.5:分页KV缓存
支持分页KV缓存(即PagedAttention)。感谢 @beginlner 的贡献。
代码目录flash_attn/modules/block.py
代码解读:Block
在这篇博客中,我们将逐段解读 Block
类的代码。该类实现了一个通用的块结构,广泛应用于Transformer等模型中。我们将详细介绍代码实现流程和所用到的理论基础。
理论基础
在Transformer架构中,最基本的构件是编码器和解码器层(block)。每个层通常包括以下部分:
- 多头自注意力机制:用于计算每个词对其他词的注意力权重。
- 前馈神经网络:对每个词的表示进行非线性变换。
- 残差连接和层归一化:为了稳定训练,添加了残差连接和层归一化。
有两种常见的层结构:
- Prenorm结构:层归一化在主要操作(注意力或前馈神经网络)之前应用。
- Postnorm结构:层归一化在主要操作之后应用。
代码实现流程
class Block(nn.Module):
def __init__(
self,
dim,
mixer_cls=None,
mlp_cls=None,
norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout,
prenorm=True,
resid_dropout1=0.0,
resid_dropout2=0.0,
drop_path1=0.0,
drop_path2=0.0,
fused_dropout_add_ln=False,
return_residual=False,
residual_in_fp32=False,
sequence_parallel=False,
mark_shared_params=False,
):
这段代码定义了 Block
类的构造函数。以下是参数的解释:
dim
:输入和输出的维度。mixer_cls
:用于计算注意力的类。mlp_cls
:用于前馈神经网络的类。norm_cls
:用于层归一化的类。dropout_cls
:用于Dropout的类。prenorm
:是否使用Prenorm结构。resid_dropout1
和resid_dropout2
:残差连接的Dropout率。drop_path1
和drop_path2
:用于Stochastic Depth的参数。fused_dropout_add_ln
:是否融合Dropout、Add和LayerNorm操作。return_residual
:是否在每个子层返回残差。residual_in_fp32
:是否使用FP32精度保存残差。sequence_parallel
:是否并行处理序列。mark_shared_params
:是否标记共享参数。
super().__init__()
self.prenorm = prenorm
self.fused_dropout_add_ln = fused_dropout_add_ln
self.return_residual = return_residual
self.residual_in_fp32 = residual_in_fp32
if self.residual_in_fp32:
assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
if mixer_cls is None:
mixer_cls = partial(MHA, num_heads=dim // 64)
if mlp_cls is None:
mlp_cls = partial(Mlp, hidden_features=4 * dim)
self.mixer = mixer_cls(dim)
self.dropout1 = dropout_cls(resid_dropout1)
self.drop_path1 = StochasticDepth(drop_path1, mode="row")
self.norm1 = norm_cls(dim)
self.mlp = mlp_cls(dim)
if not isinstance(self.mlp, nn.Identity):
self.dropout2 = dropout_cls(resid_dropout2)
self.drop_path2 = StochasticDepth(drop_path2, mode="row")
self.norm2 = norm_cls(dim)
在构造函数中,首先初始化了各个参数。根据 prenorm
、fused_dropout_add_ln
等标志设置了一些断言和默认值。如果没有提供 mixer_cls
和 mlp_cls
,则使用默认的多头注意力机制(MHA)和前馈神经网络(Mlp)。
接下来,初始化了 mixer
、dropout
、StochasticDepth
和 norm
层。
if self.fused_dropout_add_ln:
assert layer_norm_fn is not None, "Triton is not installed"
assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
self.dropout1, nn.Dropout
)
if sequence_parallel:
for p in self.norm1.parameters():
p._sequence_parallel = True
if hasattr(self, "norm2"):
for p in self.norm2.parameters():
p._sequence_parallel = True
if mark_shared_params:
for p in self.norm1.parameters():
p._shared_params = True
if hasattr(self, "norm2"):
for p in self.norm2.parameters():
p._shared_params = True
这段代码处理了 fused_dropout_add_ln
、sequence_parallel
和 mark_shared_params
的情况。如果启用了 fused_dropout_add_ln
,则确保安装了 Triton,并且 norm1
和 dropout1
是有效类型。如果启用了 sequence_parallel
,则将 norm1
和 norm2
的参数标记为需要序列并行。如果启用了 mark_shared_params
,则将这些参数标记为共享参数。
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
def forward(
self,
hidden_states: Tensor,
residual: Optional[Tensor] = None,
mixer_subset=None,
mixer_kwargs=None,
):
定义了两个方法:
allocate_inference_cache
:为推理阶段分配缓存。forward
:前向传播函数,处理输入的hidden_states
和residual
。
if self.prenorm:
if not self.fused_dropout_add_ln:
dropped = self.drop_path1(self.dropout1(hidden_states))
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path1.p == 0 or not self.training:
rowscale1 = None
else:
rowscale1 = self.drop_path1(
torch.ones(
hidden_states.shape[:-1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
)
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm1.weight,
self.norm1.bias,
residual=residual,
eps=self.norm1.eps,
dropout_p=self.dropout1.p if self.training else 0.0,
rowscale=rowscale1,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm1, RMSNorm)
)
if mixer_kwargs is None:
mixer_kwargs = {}
if mixer_subset is not None:
mixer_kwargs["mixer_subset"] = mixer_subset
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
if mixer_subset is not None:
residual = residual[:, mixer_subset]
if not isinstance(self.mlp, nn.Identity):
if not self.fused_dropout_add_ln:
dropped = self.drop_path2(self.dropout2(hidden_states))
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(
torch.ones(
hidden_states.shape[:-1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
)
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm2.weight,
self.norm2.bias,
residual=residual,
eps=self.norm2.eps,
dropout_p=self.dropout2.p if self.training else 0.0,
rowscale=rowscale2,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm2, RMSNorm)
)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
如果使用 prenorm
,则首先处理 dropout
和 residual
,然后应用 norm1
。根据 fused_dropout_add_ln
的设置,选择是否融合这些操作。之后调用 mixer
层(通常是多头注意力机制)。最后,如果 mlp
不是 Identity
,则进行类似的操作处理 mlp
层。
else:
assert residual is None
mixer_out = self.mixer(
hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
)
if self.return_residual: # mixer out is actually a pair here
mixer_out, hidden_states = mixer_out
if not self
.fused_dropout_add_ln:
hidden_states = self.norm1(
(self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
dtype=self.norm1.weight.dtype
)
)
else:
if self.drop_path1.p == 0 or not self.training:
rowscale1 = None
else:
rowscale1 = self.drop_path1(
torch.ones(
mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
)
)
hidden_states = layer_norm_fn(
mixer_out,
self.norm1.weight,
self.norm1.bias,
residual=hidden_states,
eps=self.norm1.eps,
dropout_p=self.dropout1.p if self.training else 0.0,
rowscale=rowscale1,
prenorm=False,
is_rms_norm=isinstance(self.norm1, RMSNorm)
)
if not isinstance(self.mlp, nn.Identity):
mlp_out = self.mlp(hidden_states)
if self.return_residual: # mlp out is actually a pair here
mlp_out, hidden_states = mlp_out
if not self.fused_dropout_add_ln:
hidden_states = self.norm2(
(self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
dtype=self.norm2.weight.dtype
)
)
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(
torch.ones(
mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
)
)
hidden_states = layer_norm_fn(
mlp_out,
self.norm2.weight,
self.norm2.bias,
residual=hidden_states,
eps=self.norm2.eps,
dropout_p=self.dropout2.p if self.training else 0.0,
rowscale=rowscale2,
prenorm=False,
is_rms_norm=isinstance(self.norm2, RMSNorm)
)
return hidden_states
对于 postnorm
结构,处理流程类似,但 layer norm
应用于主要操作(注意力和前馈神经网络)之后。这里的 residual
在一开始设为 None
。处理 mixer
层,之后处理 mlp
层。
通过这段代码,Block
类可以灵活地支持 prenorm
和 postnorm
结构,以及各种Dropout、残差连接和层归一化的组合。这使得它在实现不同类型的Transformer架构时非常高效和通用。
代码解读博客:ParallelBlock
在这篇博客中,我们将逐段解读 ParallelBlock
类的代码。该类实现了并行的注意力(mixer)和MLP块,类似于GPT-J、GPT-NeoX和PaLM模型的结构。
理论基础
ParallelBlock
类采用了一种略有不同于常规Transformer块的结构。传统的Transformer块通常遵循以下结构:Layer Norm (LN) -> Multi-Head Attention (MHA) / MLP -> Dropout -> Add。而 ParallelBlock
中的结构为:Dropout -> Add -> LN -> MHA / MLP。这种结构的优势在于可以融合dropout、add和LayerNorm操作,从而提升性能。
代码实现流程
class ParallelBlock(nn.Module):
"""The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
and PaLM.
"""
def __init__(
self,
dim,
mixer_cls=None,
mlp_cls=None,
norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout,
resid_dropout1=0.0,
resid_dropout2=0.0,
tied_norm=False,
fused_dropout_add_ln=False,
residual_in_fp32=False,
sequence_parallel=False,
mark_shared_params=False,
):
super().__init__()
self.tied_norm = tied_norm
self.fused_dropout_add_ln = fused_dropout_add_ln
self.residual_in_fp32 = residual_in_fp32
if mixer_cls is None:
mixer_cls = partial(MHA, num_heads=dim // 64)
if mlp_cls is None:
mlp_cls = partial(Mlp, hidden_features=4 * dim)
self.mixer = mixer_cls(dim)
self.dropout1 = dropout_cls(resid_dropout1)
self.norm1 = norm_cls(dim)
self.mlp = mlp_cls(dim)
self.dropout2 = dropout_cls(resid_dropout2)
if not self.tied_norm:
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
assert layer_norm_fn is not None, "Triton is not installed"
assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
self.dropout1, nn.Dropout
)
if sequence_parallel:
for p in self.norm1.parameters():
p._sequence_parallel = True
if hasattr(self, "norm2"):
for p in self.norm2.parameters():
p._sequence_parallel = True
if mark_shared_params:
for p in self.norm1.parameters():
p._shared_params = True
if hasattr(self, "norm2"):
for p in self.norm2.parameters():
p._shared_params = True
在构造函数中,首先初始化了各个参数。根据 tied_norm
、fused_dropout_add_ln
等标志设置了一些断言和默认值。如果没有提供 mixer_cls
和 mlp_cls
,则使用默认的多头注意力机制(MHA)和前馈神经网络(Mlp)。初始化了 mixer
、dropout
、norm
和 mlp
层,并根据条件设置了层归一化参数的并行序列和共享标志。
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
def forward(
self,
hidden_states1: Tensor,
hidden_states2: Optional[Tensor] = None,
residual: Optional[Tensor] = None,
mixer_kwargs=None,
):
r"""Pass the input through the encoder layer.
Args:
hidden_states1: the output of the previous attention (mixer) or embedding layer.
hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
residual.
"""
if not self.fused_dropout_add_ln:
dropped1 = self.dropout1(hidden_states1)
if hidden_states2 is not None:
dropped2 = self.dropout2(hidden_states2)
residual = (
(residual + dropped1 + dropped2)
if residual is not None
else dropped1 + dropped2
)
else:
residual = (residual + dropped1) if residual is not None else dropped1
hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
hidden_states2 = (
self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if not self.tied_norm
else hidden_states1
)
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
weight2, bias2 = (
(self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
)
hidden_states1, *rest, residual = layer_norm_fn(
hidden_states1,
self.norm1.weight,
self.norm1.bias,
residual=residual,
x1=hidden_states2,
weight1=weight2,
bias1=bias2,
eps=self.norm1.eps,
dropout_p=self.dropout1.p if self.training else 0.0,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm1, RMSNorm)
)
if self.tied_norm:
hidden_states2 = hidden_states1
else:
hidden_states2, = rest
if mixer_kwargs is None:
mixer_kwargs = {}
hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
hidden_states2 = self.mlp(hidden_states2)
return hidden_states1, hidden_states2, residual
在 forward
方法中,根据 fused_dropout_add_ln
的设置,选择是否融合dropout、add和LayerNorm操作。根据 hidden_states2
是否为 None
,决定是否添加dropout到 residual
中。然后应用 norm1
和 norm2
,并处理残差。调用 mixer
和 mlp
层,最后返回 hidden_states1
、hidden_states2
和 residual
。
通过这段代码,ParallelBlock
类实现了并行的注意力和MLP块结构,为模型的性能优化提供了一种有效的方法。
代码目录 flash_attn/modules/mha.py
代码解读:FlashSelfAttention
在这篇博客中,我们将逐段解读 FlashSelfAttention
类的代码。该类实现了一个带有 Softmax 的多头自注意力机制。我们将详细介绍代码实现流程和所用到的理论基础。
理论基础
自注意力机制是Transformer架构的核心。其主要原理是通过查询(query)、键(key)和值(value)来计算输入序列中每个元素的重要性权重,并根据这些权重对值进行加权求和。具体来说,自注意力机制通过以下步骤实现:
- 线性变换:将输入向量通过线性层分别映射到查询、键和值向量。
- 计算注意力权重:使用查询和键向量的点积计算注意力权重,通常会进行缩放并通过Softmax函数归一化。
- 加权求和:将归一化后的注意力权重与值向量相乘,得到输出向量。
多头注意力机制通过多个独立的注意力头来捕捉输入序列中的不同特征,并将这些特征拼接后再通过线性变换进行融合。
代码实现流程
class FlashSelfAttention(nn.Module):
"""实现带有Softmax的缩放点积注意力。
参数
---------
softmax_scale: 用于Softmax注意力的温度参数。
(默认值:1/sqrt(d_keys),其中d_keys在运行时计算)
attention_dropout: 对注意力应用的Dropout率
(默认值:0.0)
"""
def __init__(
self,
causal=False,
softmax_scale=None,
attention_dropout=0.0,
window_size=(-1, -1),
alibi_slopes=None,
deterministic=False,
):
super().__init__()
assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention未安装"
assert flash_attn_qkvpacked_func is not None, "FlashAttention未安装"
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
self.window_size = window_size
self.deterministic = deterministic
这段代码定义了 FlashSelfAttention
类的构造函数。以下是参数的解释:
causal
:是否为因果注意力,即是否考虑序列的时间顺序。softmax_scale
:Softmax的温度参数,用于缩放点积结果。attention_dropout
:注意力机制中的Dropout率。window_size
:局部窗口大小。alibi_slopes
:用于调整注意力偏置的斜率。deterministic
:是否使用确定性操作。
构造函数中首先通过断言确保所需的FlashAttention函数已安装,然后初始化各个参数,并通过 register_buffer
注册不需要梯度的参数。
def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
"""实现多头Softmax注意力。
参数
---------
qkv: 包含查询、键和值的张量。
如果cu_seqlens为None且max_seqlen为None,则qkv形状为(B, S, 3, H, D)。
如果cu_seqlens不为None且max_seqlen不为None,则qkv形状为(total, 3, H, D),
其中total是批次中序列长度的总和。
causal: 如果传递,将覆盖self.causal
cu_seqlens: (batch_size + 1,) 形状的张量,类型为torch.int32。批次中序列的累计长度,用于索引到qkv中。
max_seqlen: int。批次中最大序列长度。
返回:
--------
out: 如果cu_seqlens不为None且max_seqlen不为None,则形状为(total, H, D),
否则为(B, S, H, D)。
"""
assert qkv.dtype in [torch.float16, torch.bfloat16]
assert qkv.is_cuda
causal = self.causal if causal is None else causal
unpadded = cu_seqlens is not None
if self.alibi_slopes is not None:
self.alibi_slopes = self.alibi_slopes.to(torch.float32)
if unpadded:
assert cu_seqlens.dtype == torch.int32
assert max_seqlen is not None
assert isinstance(max_seqlen, int)
return flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal=causal,
alibi_slopes=self.alibi_slopes,
window_size=self.window_size,
deterministic=self.deterministic,
)
else:
return flash_attn_qkvpacked_func(
qkv,
self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal=causal,
alibi_slopes=self.alibi_slopes,
window_size=self.window_size,
deterministic=self.deterministic,
)
这段代码实现了 forward
方法,即前向传播过程。以下是参数的解释:
qkv
:包含查询、键和值的张量。causal
:如果传递,将覆盖self.causal
。cu_seqlens
:批次中序列的累计长度,用于索引到qkv
中。max_seqlen
:批次中最大序列长度。
在前向传播过程中,首先检查 qkv
的数据类型和设备类型。然后根据是否有 cu_seqlens
来确定是否使用未填充的序列。如果使用未填充的序列,则通过 flash_attn_varlen_qkvpacked_func
函数计算注意力;否则通过 flash_attn_qkvpacked_func
函数计算注意力。
代码小结
FlashSelfAttention
类实现了一个带有Softmax的多头自注意力机制。其主要步骤包括:
- 初始化参数。
- 在前向传播过程中,根据输入张量的形状和参数选择适当的函数计算注意力。
通过以上代码,我们可以高效地计算自注意力机制,并在需要时应用Dropout和因果注意力。
代码解读博客:FlashCrossAttention
在这篇博客中,我们将逐段解读 FlashCrossAttention
类的代码。该类实现了带有 Softmax 的缩放点积交叉注意力机制。我们将详细介绍代码实现流程和所用到的理论基础。
理论基础
交叉注意力机制与自注意力机制类似,但它使用不同的查询(query)、键(key)和值(value)来源于不同的序列。其主要步骤包括:
- 线性变换:将查询、键和值向量通过线性层进行映射。
- 计算注意力权重:使用查询和键向量的点积计算注意力权重,通常进行缩放并通过 Softmax 函数归一化。
- 加权求和:将归一化后的注意力权重与值向量相乘,得到输出向量。
交叉注意力在许多任务中具有广泛应用,如机器翻译中的编码器-解码器架构。
代码实现流程
class FlashCrossAttention(nn.Module):
"""实现带有Softmax的缩放点积注意力。
参数
---------
softmax_scale: 用于Softmax注意力的温度参数。
(默认值:1/sqrt(d_keys),其中d_keys在运行时计算)
attention_dropout: 对注意力应用的Dropout率
(默认值:0.0)
"""
def __init__(
self,
causal=False,
softmax_scale=None,
attention_dropout=0.0,
alibi_slopes=None,
window_size=(-1, -1),
deterministic=False,
):
super().__init__()
assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention未安装"
assert flash_attn_kvpacked_func is not None, "FlashAttention未安装"
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
self.window_size = window_size
self.deterministic = deterministic
这段代码定义了 FlashCrossAttention
类的构造函数。以下是参数的解释:
causal
:是否为因果注意力,即是否考虑序列的时间顺序。softmax_scale
:Softmax的温度参数,用于缩放点积结果。attention_dropout
:注意力机制中的Dropout率。window_size
:局部窗口大小。alibi_slopes
:用于调整注意力偏置的斜率。deterministic
:是否使用确定性操作。
构造函数中首先通过断言确保所需的FlashAttention函数已安装,然后初始化各个参数,并通过 register_buffer
注册不需要梯度的参数。
def forward(
self,
q,
kv,
causal=None,
cu_seqlens=None,
max_seqlen=None,
cu_seqlens_k=None,
max_seqlen_k=None,
):
"""实现多头Softmax注意力。
参数
---------
q: 包含查询的张量。形状为 (B, Sq, H, D)
kv: 包含键和值的张量。形状为 (B, Sk, 2, H_k, D)
causal: 如果传递,将覆盖self.causal
cu_seqlens: (batch_size + 1,) 形状的张量,类型为torch.int32。批次中序列的累计长度,用于索引到 q 中。
max_seqlen: int。批次中 q 的最大序列长度。
cu_seqlens_k: (batch_size + 1,) 形状的张量,类型为torch.int32。批次中序列的累计长度,用于索引到 kv 中。
max_seqlen_k: int。批次中 k 和 v 的最大序列长度。
"""
assert q.dtype in [torch.float16, torch.bfloat16]
assert q.is_cuda and kv.is_cuda
causal = self.causal if causal is None else causal
unpadded = cu_seqlens is not None
if self.alibi_slopes is not None:
self.alibi_slopes = self.alibi_slopes.to(torch.float32)
if unpadded:
assert cu_seqlens.dtype == torch.int32
assert max_seqlen is not None
assert isinstance(max_seqlen, int)
assert cu_seqlens_k is not None
assert cu_seqlens_k.dtype == torch.int32
assert max_seqlen_k is not None
assert isinstance(max_seqlen, int)
return flash_attn_varlen_kvpacked_func(
q,
kv,
cu_seqlens,
cu_seqlens_k,
max_seqlen,
max_seqlen_k,
self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal=causal,
alibi_slopes=self.alibi_slopes,
window_size=self.window_size,
deterministic=self.deterministic,
)
else:
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = kv.shape[1]
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
return flash_attn_kvpacked_func(
q,
kv,
self.drop.p if self.training else 0.0,
causal=causal,
softmax_scale=self.softmax_scale,
alibi_slopes=self.alibi_slopes,
window_size=self.window_size,
deterministic=self.deterministic,
)
这段代码实现了 forward
方法,即前向传播过程。以下是参数的解释:
q
:包含查询的张量,形状为(B, Sq, H, D)
。kv
:包含键和值的张量,形状为(B, Sk, 2, H_k, D)
。causal
:如果传递,将覆盖self.causal
。cu_seqlens
:批次中序列的累计长度,用于索引到q
中。max_seqlen
:批次中q
的最大序列长度。cu_seqlens_k
:批次中序列的累计长度,用于索引到kv
中。max_seqlen_k
:批次中kv
的最大序列长度。
在前向传播过程中,首先检查 q
和 kv
的数据类型和设备类型。然后根据是否有 cu_seqlens
来确定是否使用未填充的序列。如果使用未填充的序列,则通过 flash_attn_varlen_kvpacked_func
函数计算注意力;否则通过 flash_attn_kvpacked_func
函数计算注意力。
代码小结
FlashCrossAttention
类实现了一个带有Softmax的多头交叉注意力机制。其主要步骤包括:
- 初始化参数。
- 在前向传播过程中,根据输入张量的形状和参数选择适当的函数计算注意力。
通过以上代码,我们可以高效地计算交叉注意力机制,并在需要时应用Dropout和因果注意力。
代码解读博客:SelfAttention
在这篇博客中,我们将逐段解读 SelfAttention
类的代码。该类实现了带有 Softmax 的缩放点积自注意力机制。我们将详细介绍代码实现流程和所用到的理论基础。
理论基础
自注意力机制是Transformer架构的核心,它允许模型在计算每个词的表示时考虑序列中的其他所有词。其主要步骤包括:
- 线性变换:将查询(query)、键(key)和值(value)向量通过线性层进行映射。
- 计算注意力权重:使用查询和键向量的点积计算注意力权重,通常进行缩放并通过Softmax函数归一化。
- 加权求和:将归一化后的注意力权重与值向量相乘,得到输出向量。
代码实现流程
class SelfAttention(nn.Module):
"""实现带有Softmax的缩放点积注意力。
参数
---------
softmax_scale: 用于Softmax注意力的温度参数。
(默认值:1/sqrt(d_keys),其中d_keys在运行时计算)
attention_dropout: 对注意力应用的Dropout率
(默认值:0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
这段代码定义了 SelfAttention
类的构造函数。以下是参数的解释:
causal
:是否为因果注意力,即是否考虑序列的时间顺序。softmax_scale
:Softmax的温度参数,用于缩放点积结果。attention_dropout
:注意力机制中的Dropout率。
构造函数中初始化了因果注意力标志、Softmax缩放参数和Dropout层。
def forward(self, qkv, causal=None, key_padding_mask=None):
"""实现多头Softmax注意力。
参数
---------
qkv: 包含查询、键和值的张量。形状为 (B, S, 3, H, D)
causal: 如果传递,将覆盖self.causal
key_padding_mask: 布尔掩码,用于对注意力权重进行掩码处理。True表示保留,False表示掩码。形状为 (B, S)
"""
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
causal = self.causal if causal is None else causal
q, k, v = qkv.unbind(dim=2)
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
在前向传播过程中,首先获取批次大小和序列长度。如果传递了 causal
参数,则覆盖 self.causal
。然后将 qkv
张量沿第三个维度进行拆分,得到查询、键和值。最后,计算Softmax的缩放参数。
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
这行代码使用爱因斯坦求和约定(einsum)计算查询和键的点积,并进行缩放。scores
张量形状为 (B, H, T, S)
,表示每个查询与所有键的相似度。
if key_padding_mask is not None:
padding_mask = torch.full(
(batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
)
padding_mask.masked_fill_(key_padding_mask, 0.0)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
如果提供了 key_padding_mask
,则对注意力分数进行掩码处理。首先创建一个填充掩码,将指定位置的值设为一个非常小的数(例如-10000.0)。然后,将这个掩码应用到注意力分数上,掩盖掉不需要考虑的值。
if causal:
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
causal_mask = torch.triu(
torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + causal_mask.to(dtype=scores.dtype)
如果是因果注意力,则构造一个上三角掩码(只保留对角线及其以下的元素),掩盖掉未来的时间步。这个掩码用于确保当前时间步只能关注自己和之前的时间步,防止信息泄漏。
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
attention_drop = self.drop(attention)
output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
return output
计算Softmax得到注意力权重,然后应用Dropout。最后,使用注意力权重对值进行加权求和,得到最终的输出。
代码小结
SelfAttention
类实现了一个带有Softmax的多头自注意力机制。其主要步骤包括:
- 初始化参数。
- 在前向传播过程中,根据输入张量计算注意力分数。
- 应用键填充掩码和因果掩码。
- 计算Softmax得到注意力权重,并应用Dropout。
- 使用注意力权重对值进行加权求和,得到最终输出。
结论
FlashAttention及其改进版FlashAttention-2为注意力机制在深度学习中的应用提供了显著的速度和内存优化,使得处理长序列数据变得更加高效。希望本文对您了解和使用FlashAttention有所帮助。
如果您对FlashAttention有任何问题或建议,欢迎通过GitHub issue与我们联系。
参考链接: