【论文解析】SpAtten: Efficient Sparse Attention Architecture with Cascade Token and Head Pruning

作者及发刊详情

@article{wang2021spatten,
title={SpAtten: Efficient Sparse Attention Architecture with Cascade Token and Head Pruning},
author={Wang, Hanrui and Zhang, Zhekai and Han, Song},
journal={HPCA},
year={2021}
}

摘要

正文

主要工作贡献

1)Cascade Token Pruning根据累计的token重要性score移除了不重要的token计算

减少了DRAM访问和计算

2)Cascade Head Pruning移除了不重要的head

减少了DRAM访问和计算

3)Progressive Quantization用小部分增加的计算换取较小的memory访问

基于attention可能性分布,改变了不同head和layer的位宽,减少了DRAM访问

4)专用的高并行top-k引擎有效支持on-the-fly的token和head选择,获得了O(n)时间复杂度

即时的token和head选择

实验评估

实验验证平台:SpinalHDL+Verilator

通过SpinalHDL编写并编译成RTL,通过Verilator进行仿真
HBM采用Ramulator模型[15]

选用模型: BERT 和 GPT-2

包括BERT-Base和 BERT-Large,以及 GPT-2-Small 和GPT-2-Medium

测试集

包括GLUE set [4], SQuAD [8], Wikitext-2 [9], Wikitext-103 [9], Pen Tree Bank [10] 和Google One-Billion Word [11] 在内的30个benchmark

PPA表现

在没有精度损失下减少了10倍的DRAM访问
获得了计算加速和能量节约。

与其他ASIC对比

在这里插入图片描述

设计思想

通过级联token剪枝、级联head剪枝以及量化减少计算和存储访问。
剪枝与传统方法不同,用于token和head,而不是weights;级联是指一旦一个token或者head被剪掉了,其会在后续所有层中被除去,所以一层只需要处理以前层中余下的tokens和heads。

Token和Head的级连剪枝

  1. 人类语言中有很多结构性的无异于的token,比如介词、副词等,这些是不重要的,对结果影响较小。
  2. attention使用很多head来抓取依赖,但部分是冗余的,可以被剪枝掉

通过attention_prob和attention_out计算head和token的重要性值:
在这里插入图片描述

局部V值剪枝

根据当前atterntion_prob的值,如果该值较小则修剪,softmax与V将不在计算

渐进量化

公式验证表明当attention_prob值满足 0 < = P < = 1 , 2 P ( 1 − P ) < 0.5 0<=P<=1, 2P(1-P)<0.5 0<=P<=1,2P(1P)<0.5,softmax后的整体量化误差会减小。
开始时采用最积极的量化策略,即只采用MSB计算,当计算后attention_prob分布被几个token占据,量化误差就很小;如果比较平坦,误差就大,将计入LSB重新计算attention_prob。

在这里插入图片描述

硬件设计

top-k根据token的重要性score排序获取最终的k值Ks,数据获取器计算Ks的地址并将其填充到32x16的crossbar中,两个crossbar连接着16通道的HBM,再通过一个反向的16x32的crossbar获取数据;Q和K通过一个举证向量乘模块计算,获取attention_score,随后通过softmax模块计算attention_prob,然后将结果发送到渐进量化模块,判断LSB是否需要;局部Value 剪枝的TopK引擎获取attention_prob并计算局部最重要的Vs,并将其编号发送到数据获取器;最终幸存的attention_prob乘以V,得到attention输出,
在计算完一个head后,head重要性score将会被累计,完成一层所有的head计算后,top-k模块将会剪去不重要的head,这些head在后续的层中不被计算。

在识别阶段,在cascade token剪枝中幸存下来的K和V将被传输到SRAM,并将在多个Q中重用。
在生成阶段,Q是一个单一向量,不存在KV的复用,无需再片上SRAM中存储。
在这里插入图片描述
关键路径6、7、8、10、11是全流水化的;
模块3中QKV三个分支需要同步,fetcher处理请求并把地址发送给FIFO
对于渐进量化模块9,如果LSB需要,将放弃当前attention_prob并重新计算,此时模块10和11将处于空闲状态,等待重新计算attention_prob 。

top-k引擎

支持token/head剪枝以及局部V值剪枝,需要找到阵列中top的k个要素
随机选择一个中枢将输入整列划分成2部分:比中枢大的部分和小的部分,并储存在两个FIFO中。
左右比较器整列分别保存比pivot小和大的元素,其他的将被置为零或者被零消除器(Zero Eliminator)消除为零。
快速选择器(Quick Select)将会迭代运行,知道找到第k个最大的元素。
第k个最大元素用于过滤输入阵列,过滤掉的阵列将被缓存到另一个FIFO中。
过滤掉的输出将被另一个零消除器处理从而或则最终的top-k元素,其可以很容易地用更多的比较器进行扩展。每个阵列中使用16个比较器从而使该模块不会成为整个流水线的瓶颈。
在这里插入图片描述

Zero Eliminator

在这里插入图片描述

Crossbar和Data Fetcher

剪枝将会导致随机的memory访问,crossbar处理地址,保持访存通道busy,增加带宽的使用率。

QKV数据获取器用于将多个read请求发送到所有的16个HBM通道,这些QKV数据在不同的通道中是交错的,通过32-16 crossbar来路由这些读请求到正确的通道上。

bitwidth converter

为了支持渐进量化,通过位宽转换器实现MSB和LSB的分离与聚合。
将DRAM来的4、8、12bit数据统一转换乘12bit,转换器包含很多MUX模块来选择输入数据正确的bit,以及一个位移器允许从非对齐的地址读取数据。

QK乘模块

计算QK的矩阵乘
每个周期,K矩阵的一行从K SRAM中加载出来通过一个乘法阵列乘以Q,后填充到加法树中。
加法树通过减少所有乘的记过来计算Attention Score。
本设计采用512个乘法器来充分使用DRAM带宽。
为了支持维度D低于512的QK计算,在每个周期获得512/D个attention score。
在这里插入图片描述

softmax和渐进量化模块

从乘模块输出的attention score定点数据需要先通过一个缩放因子(scaling factor)解量化。该因子包括了标准化因子 D \sqrt{D} D 的计算。
在softmax计算后将重新量化输入到渐进量化模块,判断LSB是否需要。
在这里插入图片描述

两个SRAM

地址FIFO:32个64深度的8B
数据FIFO:32个64深度的16B

参考文献

[15] Y. Kim et al., “Ramulator: A fast and extensible dram simulator,” IEEE Computer architecture letters, 2015.

提出了一个算法和硬件架构协同设计的加速器,将算法的优化用硬件实现,加速了Attention的计算,减少了DRAM的访问,但不加速FC层,全连接层仍由GPU、CPU或tensor代数加速器计算

spatten@MIT官网

该项目已开源,采用chisal语言编写
Github开源链接

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

KGback

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值