flash-attention代码逻辑

  1. setup.py:python项目中,setup.py用于管理项目的构建、打包和分发过程。这个文件通常包含项目的元数据以及如何构建和安装模块的指令
    • 三个相关命令
      • 构建扩展模块:python setup.py build_ext
      • 清理构建文件:python setup.py clean
      • 安装到系统:python setup.py install。在项目根目录下,通过运行该命令来构建和安装你的包,这将会执行setup.py文件中的setup()函数,并根据其中的配置将包构建成一个分发包,并安装到python环境中
    • 运行python setup.py install后发生的事情:
      • 环境检查:python检查setup里面列出的依赖项是否已经安装。若没有则尝试安装
      • 构建包:使用find_packages()找到所有可用的子模块并准备构建
      • 编译扩展:如果有C/C++扩展模块,使用指定的构建工具(如Ninja)来编译这些扩展
      • 安装包:将包和所有依赖项安装到python的site-packages目录,使得包可以在python中被导入和使用
      • 验证安装:安装完后,用户可在python环境中使用import PACKAGE_NAME来验证安装是否成功

也就是,setup.py就是为了把编译后的结果打包成一个python包然后安装在环境当中的。setup.py其中包含了编译流程(ext_modules),等运行完之后,用户可在python环境中使用import PACKAGE_NAME来验证安装是否成功

setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(  // 用于查找包中可分发的所有子模块。exclude参数指定要排除的目录,这些目录不会被打包。通常会排除测试、文档和构建目录
        exclude=(
            "build",
            "csrc",
            "include",
            "tests",
            "dist",
            "docs",
            "benchmarks",
            "flash_attn.egg-info",
        )
    ),
    author="Tri Dao",
    author_email="tri@tridao.me",
    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Dao-AILab/flash-attention",
    classifiers=[  // 一组字符串,用于提供关于包的元数据,比如python版本、许可证类型和操作系统
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,  // 指定C/C++扩展模块,如果没有扩展模块通常设为None。如果有C/C++扩展模块,就使用的构建工具(如Ninja)来编译这些扩展
    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension}  // 用于定义命令的字典
    if ext_modules
    else {
        "bdist_wheel": CachedWheelsCommand,
    },
    python_requires=">=3.8",
    install_requires=[
        "torch",
        "einops",
    ],
    setup_requires=[ 
        "packaging",
        "psutil",
        "ninja",
    ],
)

  1. “编译”与ext_modules
    • 编译:
      如上面所说,运行python setup.py install的过程会检查是否有C/C++扩展模块,若有的话就进行编译。

      具体来说,编译扩展是将用C/C++编写的代码编译成共享库(动态链接库),这个库可以被python直接导入和使用。这使得python能够调用高性能的底层代码,通常用于加速计算密集型任务。

      编译完成后,生成的共享库通常会是一个.so(Linux)、.dll(Windows)或.dylib(macOS)结尾的文件,这些文件可以在python中通过import语句直接导入。
    • ext_modules:是一个列表,包含了所有需要编译的扩展模块。通常由setuptoolsExtension类构建(from setuptools import Extension)。这里是使用from torch.utils.cpp_extension import CUDAExtention。在setup()函数中,ext_modules参数指向这个扩展模块列表,当用户运行python setup.py install时,setuptools会读取这些信息,调用编译器进行编译。如果定义了多个扩展模块,它们会在同一次构建过程中被编译并链接到最终的python包中。

      编译后的扩展模块可以被python代码直接调用,就像普通的python模块一样。

      如下面,name是“flash_attn_2_cuda”的意思就是编译好的库怎么引用呢,就是通过import flash_attn_2_cuda来引用。
    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=[
                "csrc/flash_attn/flash_api.cpp",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
               "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"] + generator_flag,
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_HALF2_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        # "--ptxas-options=-v",
                        # "--ptxas-options=-O2",
                        # "-lineinfo",
                        # "-DFLASHATTENTION_DISABLE_BACKWARD",
                        # "-DFLASHATTENTION_DISABLE_DROPOUT",
                        # "-DFLASHATTENTION_DISABLE_ALIBI",
                        # "-DFLASHATTENTION_DISABLE_SOFTCAP",
                        # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                        # "-DFLASHATTENTION_DISABLE_LOCAL",
                    ]
                    + generator_flag
                    + cc_flag
                ),
            },
            include_dirs=[
                Path(this_dir) / "csrc" / "flash_attn",
                Path(this_dir) / "csrc" / "flash_attn" / "src",
                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )
    
    
    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=renamed_sources,
            extra_compile_args=extra_compile_args,
            include_dirs=include_dirs,
        )
    )
    
    • 通过编译扩展,开发者可以利用C/C++的性能优势,同时保持python的易用性,这对于需要高性能计算的应用尤为重要
  2. torch.utils.cpp_extension.CUDAExtension介绍
    是pytorch提供的一个类,用于方便地构建和编译CUDA扩展。它封装了与CUDA相关的编译过程,允许用户在pytorch中轻松集成自定义的CUDA代码
    • 几个功能:
      • 编译CUDA代码:允许用户指定CUDA源文件及相关的编译选项,从而生成可以在python中使用的共享库
      • 集成C++代码:用户可以将C++代码与CUDA代码结合,创建复杂的扩展
      • 简化配置:提供了一种简单的方法来管理编译过程中的各种设置,如头文件路径、库文件、编译器标志等
    • 使用方法:
      from torch.utils.cpp_extension import CUDAExtension, setup
      
      ext_modules = [
          CUDAExtension(
              name='my_cuda_extension',  # 模块名称
              sources=['src/my_cuda_extension.cpp',  # 源文件。即包含实际代码的文件,定义了要实现的功能或算法
                       'src/my_cuda_extension_kernel.cu'],
              include_dirs=['/path/to/include'],  # 包含头文件的目录。包含了函数声明、宏定义和数据结构的定义。头文件使得不同源文件可以共享和复用代码
              libraries=['mylib'],  # 链接的库。是编译时需要引用的外部库,它们提供额外的功能,通常是在编译的过程中
                                    # 与扩展模块进行链接。链接库可以是静态库(.a文件)或动态库(.so或.dll文件)
              library_dirs=['/path/to/lib'],  # 库文件路径。指存放链接库的目录。当编译器在链接阶段寻找库文件时会使用这个路径
              extra_compile_args={
                  "cxx": ["-O3", "-std=c++17"] + generator_flag,  # -03:启用最高级别的优化,通常会生成更快但是编译时间更长的代码
                                                                  # -std=c++17:指定使用C++17标准
                                                                  # +generator_flag:追加其他生成器特定的编译选项。generator_flag通常是动态定义的,可能与编译器或构建工具有关
                                                                  # 前面定义了 generator_flag = ["-DOLD_GENERATOR_PATH"]
                  "nvcc": append_nvcc_threads(  # 这里包含了为nvcc(nvidia CUDA编译器)指定的编译选项
                      [
                          "-O3", 
                          "-std=c++17",
                          "-U__CUDA_NO_HALF_OPERATORS__",
                          "-U__CUDA_NO_HALF_CONVERSIONS__",
                          "-U__CUDA_NO_HALF2_OPERATORS__",
                          "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                          "--expt-relaxed-constexpr",
                          "--expt-extended-lambda",
                          "--use_fast_math",  # 启用快速数学库以提高性能,但可能以牺牲准确性为代价
                          # "--ptxas-options=-v",  # 编译时显示ptxas的详细信息,有助于调试
                          # "--ptxas-options=-O2",
                          # "-lineinfo",
                          # "-DFLASHATTENTION_DISABLE_BACKWARD",
                          # "-DFLASHATTENTION_DISABLE_DROPOUT",
                          # "-DFLASHATTENTION_DISABLE_ALIBI",
                          # "-DFLASHATTENTION_DISABLE_SOFTCAP",
                          # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                          # "-DFLASHATTENTION_DISABLE_LOCAL",
                      ]
                      + generator_flag
                      + cc_flag
                  ),
          	},
          )
      ]
      
  3. 所以,要看的就是sources里的文件,这些就是要编译的CUDA源文件
    • 它们实现了不同版本的前向和反向传播算法:fp16/bf16、fwd/bwd、hdim、causal、split
    • flash_api.cpp:Flash Attention API 的定义和实现,用于提供 Python 和 CUDA 代码之间的接口。
    • flash_fwd_hdimXX_fp16_sm80.cu:这些是 CUDA 源文件,涉及前向计算的实现,hdimXX 表示模型的隐藏维度(例如,32, 64, 96, 128, 160, 192, 256),fp16 指使用16位半精度浮点数(另外还有bf16),sm80 指该文件是为特定的 CUDA 架构(例如,80对应于 Ampere架构)编写的
    • flash_fwd_hdimXX_fp16_causal_sm80.cu:这些文件是针对因果前向计算的实现(含掩码),适用于语言模型等需要因果注意力的任务。它们同样根据不同的隐藏维度和数据类型进行分类
    • flash_bwd_hdimXX_fp16_sm80.cu:实现了backward反向传播的计算,用于训练过程中的梯度计算
    • flash_bwd_hdimXX_fp16_causal_sm80.cu:实现了因果模型的反向传播
    • flash_fwd_split_hdimXX_fp16_sm80.cu:实现了针对特定隐藏维度的分割前向计算,可能是为了更高效地处理大型输入(??)
    • flash_fwd_split_hdimXX_fp16_causal_sm80.cu
  4. 所以改的话,就是改fwd、causal=false、(看下默认参数配置?
  5. flash_api.cpp
    • set_params_fprop
    • set_params_dgrad
    • run_mha_fwd
    • num_splits_heuristic
    • set_params_splitkv
    • set_params_alibi
    • mha_fwd
    • mha_varlen_fwd
    • run_mha_bwd
    • mha_bwd
    • mha_varlen_bwd
    • mha_fwd_kvcache
    • pybind定义:
      PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
          m.doc() = "FlashAttention";
          m.def("fwd", &mha_fwd, "Forward pass");  // 定义一个名为fwd的函数,绑定到上面的mha_fwd函数,并为该函数提供文档字符串“Forward pass”,这表示该函数实现了前向传播的计算逻辑
          m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
          m.def("bwd", &mha_bwd, "Backward pass");
          m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
          m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
      }
      
  6. 总体调用流程:
    • from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
    • 编译好的.so文件,使用其.fwd

    从一个CUDA+Python联合调试的文章里清晰了解了一个CUDA项目的编译过程:
    原始项目的目录树为:
    请添加图片描述
    其中,cuda_hello.cu是待调试的CUDA代码(里面定义了一个打印hello的核函数和一个主机端调用接口launch_cuda_hello);

    pybind_wrapper.cpp使用pybind11这个库将CUDA代码中的主机调用接口函数注册到Python中(具体就是,先创建一个名为cuda_hello的python模块,然后将外部的主机函数launch_cuda_hello与新建python包中的函数名hello关联。最终在python中的使用方法就是:import cuda_hello,然后cuda_hello.hello())。如下,PYBIND11_MODULEpybind11提供的宏,用于定义一个python模块,下面的代码中,模块名设为cuda_hello,并传入了m作为模块对象的引用,通过m为这个模块添加函数和类:
    PYBIND11_MODULE(cuda_hello, m) { m.def("hello", &launch_cuda_hello, "A function that launches a CUDA kernel to print Hello"); }

    test_cuda_hello.py中,通过动态链接库导入cuda_hello这个包,并通过上述方法调用该包中的launch_cuda_hello函数
    import cuda_hello
    cuda_hello.hello()

    CMakeLists.txt文件中,设置CUDA标准、CUDA架构、C++ 标准等一系列配置,以及配置刚刚定义的编译源代码:查找pybind11包、添加CUDA源代码并创建共享库(add_library(cuda_functions SHARED src/cuda_hello.cu))、创建pybind11模块(pybind11_add_module(cuda_hello src/pybind_wrapper.cpp))、将CUDA函数库链接到pybind11模块(target_link_libraries(cuda_hello PRIVATE cuda_functions))。
    即准备好pybind11->把cuda源文件打包成共享库->用pybind11创建一个python模块->将cuda共享库链接到python模块中,使python模块能执行GPU代码

    • .fwd绑定的是flash_api.cpp中的mha_fwd函数
    • mha_fwd在完成初始化后,调用run_mha_fwd(params, stream)(依然定义在flash_api.cpp中)进行前向计算
    • run_mha_fwd会根据 – 1)数据类型(params.is_bf16)、2)维度(params.d)、3)是否采用causal attention(params.is_causal) – 来调用run_mha_fwd_函数(或若force_split_kernel,调用run_mha_fwd_splitkv_dispatch函数)并传入elem_typekHeadDimIs_causal三个参数
      • run_mha_fwd_函数声明在flash.h中(在flash_api.cpp中要include flash.h):
        template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
        
    • flash_fwd_launch_template.h介绍:通过宏定义和模板参数来生成不同变体的内核函数,从而适配不同的硬件架构、输入条件和操作模式
      • 包含头文件:主要涉及CUDA上下文、flash-attention计算
        • #include <ATen/cuda/CUDAContext.h>:是pytorch中的一个头文件,这个文件定义了与CUDA相关的上下文管理功能,主要用于处理CUDA设备的初始化、设备上下文切换以及流管理。
          • ATen是pytorch的底层tensor库,提供了tensor计算、自动求导等基础功能
          • CUDAContext:负责设备初始化、设备选择、与CUDA流有关的操作(CUDA流允许在GPU上并行执行多个任务)、以及与CUDA相关的资源管理
          • 该头文件是pytorch中实现GPU加速计算的关键部分
            #include <ATen/cuda/CUDAContext.h>
            
            // 获取当前 CUDA 设备信息
            int current_device = at::cuda::current_device();
            
            // 切换 CUDA 设备
            at::cuda::set_device(0);
            
            // 获取默认 CUDA 流
            cudaStream_t stream = at::cuda::getCurrentCUDAStream();
            
        • #include "static_switch.h":通过一系列宏定义(如FP16_SWITCHHEADDIM_SWITCHBOOL_SWITCH)来简化和优化在编译时的条件分支处理。这些宏根据布尔或其他条件,在编译或运行时选择执行不同的代
        • #include "flash.h"
          • 定义如下结构体:Qkv_paramsFlash_fwd_paramsFlash_bwd_params
          • 定义如下函数模版:run_mha_fwd_run_mha_fwd_splitkv_dispatchrun_mha_bwd_
        • #include "flash_fwd_kernel.h":主要就是进行attention的计算,且本头文件中定义的函数都放在namespace flash下面。具体定义如下函数:
          • get_lse_tile
          • compute_attn:计算attention的外部逻辑函数,它会先获取块索引,然后调用compute_attn_1rowblock并将之前定义的参数和当前块索引传进去,进行实际的单行attention计算
          • compute_attn_splitkv:和上面compute_attn的原理差不多,区别就是它支持split kv机制,能适应多头注意力的复杂需求,能通过分割逻辑优化性能
          • compute_attn_1rowblock:用于计算单个行块(row block)上的attention
          • compute_attn_1rowblock_splitkv
          • combine_attn_seqk_parallel:结合多个attention头的计算结果,以计算最终的输出
      • 定义了三个核函数:flash_fwd_kernelflash_fwd_splitkv_kernelflash_fwd_splitkv_combine_kernel。分别调用flash::compute_attnflash::compute_attn_splitkvflash::combine_attn_seqk_parallel进行attention的计算
      • 定义了三个主机函数run_flash_fwdrun_flash_splitkv_fwdrun_mha_fwd_splitkv_dispatch,分别调用了上面的flash_fwd_kernelflash_fwd_splitkv_kernelflash_fwd_splitkv_combine_kernel
      • 定义了不同维度的主机函数:run_mha_fwd_hdim32run_mha_fwd_hdim64run_mha_fwd_hdim96run_mha_fwd_hdim128run_mha_fwd_hdim160run_mha_fwd_hdim192run_mha_fwd_hdim256,会调用run_flash_fwd
      • 也就是flash_fwd_launch_template实际上是包裹了flash_fwd_kernel.h的实现(现在还未知是从外部的哪里调用了flash_fwd_launch_template,以及内部flash_fwd_kernel.h具体是如何实现的,如果没啥问题,应该就是改这个头文件了。但是有个疑问,就是他函数逻辑是定义在一个头文件里?)
src中的包裹逻辑

flash_fwd_kernel.h:实现一行块一行块的attention

flash_fwd_launch_template.h:实现不同维度的run_mha_fwd_hdim256,进行run_flash_fwd函数的调用。run_flash_fwd再根据其他参数进行flash_fwd_kernel的调用,核函数flash_fwd_kernel会调用flash_fwd_kernel.h中的具体计算逻辑

具体在每个flash_fwd_hdim?_bf16?_sm80.cu文件中,会include上面的flash_fwd_launch_template.h,然后具体定义run_mha_fwd_函数:根据参数来调用具体的填满维度的函数,如run_mha_fwd_hdim96

最终在外部接口flash_api.cpp中,调用run_mha_fwd_函数

结论,所以改的话,只需要看flash_fwd_launch_template.h(每准这个也不用改)和flash_fwd_kernel.h即可。前者是分配了不同维度,后者是具体的计算

src中一些概念性定义的头文件
  1. kernel_traits.h:定义了三个结构体
    • struct Flash_kernel_traits:封装了不同CUDA架构的特性和操作,包括定义别名、定义MMA(矩阵乘法原子)、定义SmemCopyAtom和SmemCopyAtomTransposed(共享内存复制原子)
    • struct Flash_fwd_kernel_traits : public Base:继承了上面的 struct Flash_kernel_traits,并在前向计算中增加了特定的优化和数据布局方式。总的来说,这个结构体是对flash attention前向计算核函数的执行特性进行描述的,其描述了在GPU上计算attention时所设计的关键参数、内存布局和优化策略。结构体描述的内容包括:
      说白了作用就是根据根据CUDA架构选择不同的内存布局、复制方式、核函数参数(如KNThreadskBlockM等参数控制核函数执行时的线程数和块大小,确保核函数适合在不同的矩阵大小和head_dim下执行)和矩阵运算原子。
      • 线程和块大小:定义了核函数执行时的线程数、线程块大小、并行计算的warp数,这些参数决定了计算过程中每个线程处理的数据量等
      • 内存布局和访问模式:描述了Q、K、V矩阵在shared memory和global memory中的布局方式(SmemLayoutQSmemLayoutKVGmemLayoutAtom等),通过这些布局来确保在GPU内存结构中高效读取和写入数据;同时使用特定的复制方式(SmemCopyAtomGmemTiledCopyQKV)来减少共享内存的冲突和优化全局内存的带宽使用
      • 架构优化:根据不同的硬件架构选择不同的优化策略,如是否使用cp.async进行异步数据传输、根据是fp16还是bf16来选择不同的矩阵乘算法(MMA_Atom_Arch
      • attention优化:如使用kHeadDim定义了头部维度如何影响内存分配和复制方式,特别是在不同数据分块策略下,确保高效的矩阵乘法和内存操作
    • struct Flash_bwd_kernel_traits: public Base
flash_fwd_kernel.h 的具体实现
inline device void compute_attn_1rowblock(const Params &params, const int bidb, const int bidi, const int m_block)
  1. Kernel_traits
    • flash_fwd_kernel.h:在模板函数中定义typename Kernel_traits
    • flash_fwd_launch_template.h-flash_fwd_kernel核函数:flash_fwd_kernel核函数是模板函数(该模板函数又是通过宏来定义的,即通过宏定义固定格式生成多个核函数,然后此flash_fwd_kernel核函数又通过自身模板函数的特性,可传入不同类型参数/不同参数值并在编译时就确定其值),其中就有typename Kernel_traits,进而在该核函数里通过调用flash_fwd_kernel.h中具体的attention计算函数来进行Kernel_traits的传递(传给上面)
      flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, ...>(params);
      
    • flash_fwd_launch_template.h-run_flash_fwd主机函数:run_flash_fwd也是模板函数,定义了typename Kernel_traits,进而在该主机函数里通过调用上面的flash_fwd_kernel核函数来进行Kernel_traits的传递
      // run_flash_fwd函数的定义如下:
      template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
      void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream){...}
      
      // run_flash_fwd函数中具体调用上面核函数的部分代码如下:
      auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout&&!Is_softcap,...>;
      kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
      
    • flash_fwd_launch_template.h-run_mha_fwd_hdim?:以run_mha_fwd_hdim64为例,该函数会调用上面的run_flash_fwd函数:
      constexpr static int Headdim = 64;
      run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);  
      
      这就找到了Kernel_traits了。根据上面run_flash_fwd的函数定义可知,Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>就是具体传入的Kernel_traits。这是一个定义在kernel_traits.h中的结构体,在flash_fwd_launch_template.h中存在#include "flash_fwd_kernel.h",在flash_fwd_kernel.h中存在#include "kernel_traits.h",所以这里可以直接使用
  2. 具体该函数内执行q、k矩阵乘的部分:
    请添加图片描述
    然后又调用了这里

    最后调用了cute::gemm,就是cutlass的实现了
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值