flash-attention保姆级安装教程

FlashAttention安装教程

FlashAttention 是一种高效且内存优化的注意力机制实现,旨在提升大规模深度学习模型的训练和推理效率。

  • 高效计算:通过优化 IO 操作,减少内存访问开销,提升计算效率。

  • 内存优化:降低内存占用,使得在大规模模型上运行更加可行。

  • 精确注意力:保持注意力机制的精确性,不引入近似误差。

  • FlashAttention-2 是 FlashAttention 的升级版本,优化了并行计算策略,充分利用硬件资源。改进了工作负载分配,进一步提升计算效率。

  • FlashAttention-3:FlashAttention-3 是专为 Hopper GPU(如 H100)优化的版本,目前处于 Beta 测试阶段。


常见问题:
安装成功后,实际模型代码运行时报错未安装,核心原因就是cxx11abiFALSE这个参数,表示该包在构建时不启用 C++11 ABI。
必须开启不使用才行。否则报错如下:
ImportError: This modeling file requires the following packages that were not found in your environment: flash_attn.


最佳安装步骤(方法1)

  1. 安装依赖
    • 基础环境:cuda12.1、nvcc.
    • 安装python,示例3.10。
    • 安装PyTorch,示例torch2.3.0; torchvision0.18.0
    • ninja Python 包
  2. 获取releases对应的whl包
    - 地址:https://github.com/Dao-AILab/flash-attention/releases
    - 按照系统环境选whl
    在这里插入图片描述
    3. 我的环境对应的包是:flash_attn-2.7.2.post1+cu12torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl,解释如下:
    • flash_attn: 包的名称,表示这个 Wheel 文件是 flash_attn 包的安装文件。
    • 2.7.2.post1: 包的版本号,遵循 PEP 440 版本规范。
      • 2.7.2: 主版本号,表示这是 flash_attn 的第 2.7.2 版本。
      • post1: 表示这是一个“后发布版本”(post-release),通常用于修复发布后的某些问题。
    • +cu12torch2.3cxx11abiFALSE: 构建标签,表示该 Wheel 文件是在特定环境下构建的。
      • cu12: 表示该包是针对 CUDA 12 构建的。
      • torch2.3: 表示该包是针对 PyTorch 2.3 构建的。
      • cxx11abiFALSE: 表示该包在构建时不启用 C++11 ABI(Application Binary Interface)。如果安装包后不识别,就要选为False的版本。
    • cp310: Python 版本的标签,表示该包是为 Python 3.10 构建的。
      • cp310: 是 cpython 3.10 的缩写,表示该包适用于 CPython 解释器的 3.10 版本。
    • linux_x86_64: 平台标签,表示该包是为 Linux 操作系统和 x86_64 架构(即 64 位 Intel/AMD 处理器)构建的。
    • .whl: 文件扩展名,表示这是一个 Python Wheel 文件。Wheel 是 Python 的一种二进制分发格式,用于快速安装包。

如何安装

可以使用 pip 安装这个 Wheel 文件:

pip install flash_attn-2.7.2.post1+cu12torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl --no-build-isolation


常规安装步骤(方法二)

  1. 安装依赖

    • CUDA 工具包或 ROCm 工具包
    • PyTorch 1.12 及以上版本
    • packagingninja Python 包
    pip install packaging ninja
    
  2. 安装 FlashAttention

    # 后面--no-build-isolation参数是为了pip 会直接在当前环境中构建包,使用当前环境中已安装的依赖项。
    # 如果当前环境缺少构建所需的依赖项,构建过程可能会失败。
    pip install flash-attn --no-build-isolation
    

    或从源码编译:

    # 下载源码后,进行编译
    cd flash-attention
    python setup.py install
    
  3. 运行测试

    export PYTHONPATH=$PWD
    pytest -q -s test_flash_attn.py
    
  4. 补充说明

    4.1 上面运行时,建议设置参数MAX_JOBS,限制最大进程数,不然系统容易崩。本人在docker下安装,直接干重启了,所以建议如下方式运行:

    MAX_JOBS=4 pip install flash-attn --no-build-isolation
    

    4.2 如果运行时会出现警告且推理速度依旧很慢,需要继续从源码安装rotary和layer_norm,cd到源码的那两个文件夹,执行 python setup.py install进行安装,如果命令报错弃用,可能要用easy_install命令。
    在这里插入图片描述

接口使用

import flash_attn_interface
flash_attn_interface.flash_attn_func()

硬件支持

NVIDIA CUDA 支持

  • 支持 GPU:Ampere、Ada 或 Hopper 架构 GPU(如 A100、RTX 3090、RTX 4090、H100)。
  • 数据类型:FP16 和 BF16。
  • 头维度:支持所有头维度,最大至 256。

AMD ROCm 支持

  • 支持 GPU:MI200 或 MI300 系列 GPU。
  • 数据类型:FP16 和 BF16。
  • 后端:支持 Composable Kernel (CK) 和 Triton 后端。

性能优化

Triton 后端

Triton 后端的 FlashAttention-2 实现仍在开发中,目前支持以下特性:

  • 前向和反向传播:支持因果掩码、变长序列、任意 Q 和 KV 序列长度、任意头大小。
  • 多查询和分组查询注意力:目前仅支持前向传播,反向传播支持正在开发中。

性能改进

  • 并行编译:使用 ninja 工具进行并行编译,显著减少编译时间。
  • 内存管理:通过设置 MAX_JOBS 环境变量,限制并行编译任务数量,避免内存耗尽。

结论

FlashAttention 系列通过优化计算和内存使用,显著提升了注意力机制的效率。无论是研究人员还是工程师,都可以通过本文提供的安装和使用指南,快速上手并应用于实际项目中。随着 FlashAttention-3 的推出,针对 Hopper GPU 的优化将进一步推动大规模深度学习模型的发展。

参考链接


ImportError Traceback (most recent call last) Cell In[4], line 4 2 import numpy as np 3 import matplotlib.pyplot as plt ----> 4 import seaborn as sns 5 import matplotlib as mpl File D:\tool\anaconda3\lib\site-packages\seaborn\__init__.py:2 1 # Import seaborn objects ----> 2 from .rcmod import * # noqa: F401,F403 3 from .utils import * # noqa: F401,F403 4 from .palettes import * # noqa: F401,F403 File D:\tool\anaconda3\lib\site-packages\seaborn\rcmod.py:5 3 import matplotlib as mpl 4 from cycler import cycler ----> 5 from . import palettes 8 __all__ = ["set_theme", "set", "reset_defaults", "reset_orig", 9 "axes_style", "set_style", "plotting_context", "set_context", 10 "set_palette"] 13 _style_keys = [ 14 15 "axes.facecolor", (...) 50 51 ] File D:\tool\anaconda3\lib\site-packages\seaborn\palettes.py:9 5 import matplotlib as mpl 7 from .external import husl ----> 9 from .utils import desaturate, get_color_cycle 10 from .colors import xkcd_rgb, crayons 11 from ._compat import get_colormap File D:\tool\anaconda3\lib\site-packages\seaborn\utils.py:17 14 import matplotlib.pyplot as plt 15 from matplotlib.cbook import normalize_kwargs ---> 17 from seaborn._core.typing import deprecated 18 from seaborn.external.version import Version 19 from seaborn.external.appdirs import user_cache_dir File D:\tool\anaconda3\lib\site-packages\seaborn\_core\typing.py:8 5 from typing import Any, Optional, Union, Tuple, List, Dict 7 from numpy import ndarray # TODO use ArrayLike? ----> 8 from pandas import Series, Index, Timestamp, Timedelta 9 from matplotlib.colors import Colormap, Normalize 12 ColumnName = Union[ 13 str, bytes, date, datetime, timedelta, bool, complex, Timestamp, Timedelta 14 ] ImportError: cannot import name 'Series' from 'pandas' (unkn
04-03
### 解决方案 当遇到 `ImportError: cannot import name 'Series' from 'pandas'` 时,这通常表明 Pandas 库存在版本冲突或安装不完整的情况。以下是可能的原因以及解决方案: #### 可能原因分析 1. **Pandas 版本问题**: 当前使用的 Pandas 版本可能存在兼容性问题[^1]。 2. **依赖库损坏**: 安装过程中某些文件未正确下载或被覆盖[^3]。 3. **环境污染**: Anaconda 或其他 Python 环境中的包管理混乱可能导致此类错误。 --- #### 解决方法 ##### 方法一:更新 Pandas 和 Seaborn 到最新版本 确保 Pandas 和 Seaborn 是最新的稳定版本,因为旧版本可能存在已修复的 bug。 ```bash pip install --upgrade pandas seaborn ``` 如果使用的是 Conda,则可以运行以下命令: ```bash conda update pandas seaborn ``` ##### 方法二:重新创建虚拟环境并安装必要依赖 有时现有环境可能会受到污染,建议通过以下方式清理环境: 1. 创建一个新的 Conda 虚拟环境: ```bash conda create -n new_env python=3.9 ``` 2. 激活新环境: ```bash conda activate new_env ``` 3. 在新环境中安装必要的库: ```bash conda install pandas seaborn numpy scipy matplotlib ``` ##### 方法三:手动卸载并重装 Pandas 如果上述方法未能解决问题,尝试完全移除 Pandas 并重新安装: ```bash pip uninstall pandas pip install pandas ``` 或者对于 Conda 用户: ```bash conda remove pandas conda install pandas ``` ##### 方法四:检查特定模块是否存在 确认 Pandas 中确实包含 `Series` 类型定义。可以通过以下代码验证: ```python import pandas as pd print(pd.Series) ``` 如果没有抛出异常,则说明当前 Pandas 安装正常;否则可能是安装过程出现问题。 ##### 方法五:降到稳定的 Pandas 版本 部分情况下,较新的 Pandas 版本可能与其他库(如 NumPy、Seaborn)存在兼容性问题。可以选择回退到更早的稳定版: ```bash pip install pandas==1.5.3 ``` --- ### 注意事项 - 如果仍然无法解决,请提供完整的 Traceback 错误日志以便进一步诊断。 - 避免混用 `pip` 和 `conda` 命令来安装同一套件,以免引发路径冲突。 ---
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

LensonYuan

蚊子腿也是肉!感谢!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值