大模型-FlashAttention 算法分析

FlashAttention算法原理及优化分析

一、FlashAttention的概述

FlashAttention是一种IO感知精确注意力算法。通过感知显存读取/写入,FlashAttention的运行速度比PyTorch标准Attention快了2-4倍,所需内存也仅是其5%-20%。随着Transformer变得越来越大、越来越深,但它在长序列上仍然处理的很慢、且耗费内存。(自注意力时间和显存复杂度与序列长度成二次方),现有近似注意力方法,在试图通过去牺牲模型质量,以降低计算复杂度来解决该问题。但存在一定的局限性,即不能提升运行时的训练速度。

基于这样子背景,研究人员觉得应该让注意力算法具有IO感知,即考虑显存级间的读写,比如大但慢的HBM(High Bandwidth Memory)技术与小但快的SRAM。基于这样的背景,研究人员提出了FlashAttention,具体有两种加速技术:

Flash Attention主要运用了两种技术:分块和算子融合计算。

  1. 分块即将输入划分成块,并单独处理每个块的注意力;

  2. 算子融合,将多个算子融合成一个算子,将所有注意力操作融合到CUDA内核中。

在深度学习中,Attention Mechanism(注意力机制)用于在序列数据中学习不同位置的重要性权重,以便更好地捕捉序列中的关键信息。Attention Scores(注意力分数)是用来计算这些权重的指标。以下是关于计算Attention Scores的推导过程的文档:

二、FlashAttention v1具体原理的探究

先上经典的示意图

 2.1 注意力机制快速了解

在注意力机制中,Q(Query)、K(Key)、V(Value)和softmax函数是核心概念,用于计算注意力权重并生成输出表示。

Query(查询)、Key(键)、Value(值)

  • Query Q(查询):查询是用来寻找与之相关性的键的向量。在注意力机制中,查询向量Q通常是通过对输入数据进行线性变换得到的,用于衡量查询与键的相似度,决定了在生成输出表示时每个键的重要程度。

  • Key K(键):键是用来表示输入数据的向量,与查询一起计算注意力权重。键向量K也是通过对输入数据进行线性变换得到的。

  • Value V(值):值是用来生成最终输出表示的向量。值向量V同样是通过对输入数据进行线性变换得到的,根据注意力权重对值进行加权求和得到最终输出。

Softmax函数

Softmax函数是一个常用的激活函数,通常用于多分类问题中,将输入转化为概率分布。在注意力机制中,softmax函数被用来计算注意力权重,将注意力分数转化为概率分布,使得不同位置的输入在生成输出时得到不同的权重。

标准的Attention实现算法如下。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值