Flash Attention的核心原理可以总结为:
- 动机和目标
传统的自注意力机制在transformer模型中存在计算和内存效率低下的问题,尤其是对于长序列输入。Flash Attention旨在通过优化数据布局和计算流程,降低注意力计算的内存访问开销,提高计算效率。 - 切块(Tiling)策略
Flash Attention将输入的查询(Query)、键(Key)和值(Value)矩阵切分成多个小块(tile),而不是一次性将整个矩阵加载到GPU内存中。这样可以充分利用GPU的有限内存带宽。 - 内存层次利用
Flash Attention将计算过程分散到GPU的不同内存层次(HBM和SRAM)。小的数据块被加载到高带宽的SRAM中进行计