浅谈FlashAttention优化原理

FlashAttention原理简介

背景:

在Transformer 结构中,自注意力机制的时间和存储复杂度与序列的长度呈平方的关系,因此占用了大量的计算设备内存和并消耗大量计算资源。如何优化自注意力机制的时空复杂度增强计算效率是大语言模型需要面临的重要问题

1.FlashAttention

NVIDIA GPU 中的内存(显存)按照它们物理上是在GPU 芯片内部还是板卡RAM 存储芯片上,决定了它们的速度、大小以及访问限制。

GPU 显存分为:全局内存(Global memory)、本地内存(Local memory)、共享内存(Shared memory,SRAM)、寄存器内存(Register memory)、常量内存(Constant memory)、纹理内存(Texture memory)等六大类。

下图给出了NVIDIA GPU 内存的整体结构。其中全局内存、本地内存、共享内存和寄存器内存具有读写能力。全局内存和本地内存使用的高带宽显存(High Bandwidth Memory,HBM)位于板卡RAM 存储芯片上,该部分内存容量很大。

全局内存是所有线程都可以访问,而本地内存则只能当前线程访问。NVIDIA H100 中全局内存有80GB 空间,其访问速度虽然可以达到3.35TB/s,但是如果全部线程同时访问全局内存时,其平均带宽仍然很低。

共享内存和寄存器位于GPU 芯片上,因此容量很小,并且共享内存只有在同一个GPU 线程块(Thread Block)内的线程才可以共享访问,而寄存器仅限于同一个线程内部才能访问。NVIDIA H100 中每个GPU 线程块在流式多处理器(Stream Multi-processor,SM)可以使用的共享存储容量仅有228KB,但是其速度非常快,远高于全局内存的访问速度。

在这里插入图片描述

通过自注意力机制的原理可知,在GPU 中进行计算时,传统的方法还需要引入两个中间矩阵S 和P 并存储到全局内存中。具体计算过程如下:

S = Q ×K, P = Softmax(S), O = P × V (2.27)

按照上述计算过程:

  • 需要首先从全局内存中读取矩阵Q 和K,并将计算好的矩阵S 再写入全局内存

  • 再从全局内存中获取矩阵S,计算Softmax 得到矩阵P,再写入全局内容

  • 读取矩阵P矩阵V ,计算得到矩阵O

这样的过程会极大占用显存的带宽。在自注意力机制中,计算速度比内存速度快得多,因此计算效率越来越多地受到全局内存访问的瓶颈。

FlashAttention就是通过利用GPU 硬件中的特殊设计,针对全局内存和共享存储的I/O 速度的不同,尽可能的避免HBM 中读取或写入注意力矩阵

FlashAttention 目标:是尽可能高效地使用SRAM 来加快计算速度,避免从全局内存中读取和写入注意力矩阵。达成该目标需要能做到在不访问整个输入的情况下计算Softmax 函数,并且后向传播中不能存储中间注意力矩阵。

标准Attention 算法中:

  • Softmax 计算按行进行,即在与V 做矩阵乘法之前,需要将Q、K 的各个分块完成一整行的计算。

  • 在得到Softmax 的结果后,再与矩阵V 分块做矩阵乘。

而在FlashAttention 中:

  • 将输入分割成块,并在输入块上进行多次传递,从而以增量方式执行Softmax 计算。

自注意力算法的标准实现将计算过程中的矩阵S、P 写入全局内存中,而这些中间矩阵的大小与输入的序列长度有关且为二次型。

因此,FlashAttention 就提出了不使用中间注意力矩阵,通过存储归一化因子来减少全局内存的消耗。FlashAttention 算法并没有将S、P 整体写入全局内存,而是通过分块写入存储前向传递的Softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从全局内容中读取中间注意力矩阵的标准方法更快。由于大幅度减少了全局内存的访问量,即使重新计算导致FLOPs 增加,但其运行速度更快并且使用更少的内存。

PyTorch 2.0 中已经可以支持FlashAttention,使用“torch.backends.cuda.enable_flash_sdp()”启
用或者关闭FlashAttention 的使用。

  • 6
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

保持成长

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值