flash attention是一个用于加速模型训练推理的可选项,且仅适用于Turing、Ampere、Ada、Hopper架构的Nvidia GPU显卡(如H100、A100、RTX X090、T4)
1.首先检查一下GPU是否支持:FlashAttention
import torch
def supports_flash_attention(device_id: int):
"""Check if a GPU supports FlashAttention."""
major, minor = torch.cuda.get_device_capability(device_id)
# Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
return is_sm8x or is_sm90
print(supports_flash_attention(device_id)) #-> device_id 显卡号, 0 / 1 / 2 。。。
2.如果不支持,将模型文件夹中的config.json文件中的use_flash_attn改为false。
use_flash_attn参数名称可能会有些不同。
3.如果支持,将FlashAttention升级版本。
!pip install -U flash-attn