论文地址:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
FA主要思路还是通过tile技术减少在HBM和on-chip SRAM内存读写时间。FA在bert-large上端到端训练有15%的加速(seq length 512), 在GPT-2有3倍加速(seq length 1024)。
主要贡献点
- 计算softmax时候不需要全量input数据,可以分段计算。
- 反向传播的时候,不存储attention matrix (N^2的矩阵),而是只存储softmax归一化的系数。
基于 FlashAttention 技术,清华将基座模型的上下文长度(Context Length)由 ChatGLM-6B 的 2K 扩展到了 32K,并在对话阶段