GQA是2023年发表的一篇paper提出的idea,目前用在了llama2、falcon等LLM上。paper一般都篇幅众多,老规矩,本文总结出最精华的部分:) 原文首发于我的公众号"AI不止算法",文章链接在此
动机
GQA的动机主打的是MQA(multi query attention)会导致quality degradation,我们不希望仅仅是推理快,而且还希望quality可以对标MHA,所以GQA带着这个使命诞生,可以很好的做到这个balance。MQA的动机主要在于key和value的数量是随着头数量成正比,那么尤其在decoder inference的过程中,本身就是一个memory bound的过程,这下更加memory bound了,带宽的压力山大,速度快不起来,所以呢,减少头的数量,减少kv cache的size,达到减小带宽的压力的目的,那么MQA推理速度势必更快。
概念
在19年的时候也有一篇paper提出了一个叫做MQA(multi query attention)的idea,GQA可以看作是MQA和MHA的中间或者一般化形态,当GQA里的Group=1的时候,此时为MHA,当GQA的Group=头的数量的时候,此时为MQA,图片非常直观,我就不废话了 😃
GQA和MQA带来的推理性能提升
性能提升主要来自于kv cache的size减小,那么kv cache占用的显存就变小,那么我们LLM serving可以处理的请求数量就更多,batchsize更大,吞吐量就变大。
同时这个收益也不是白来的,那就是需要做一个broadcast,见如下llama的src code,需要用torch.repeat_interleave对kv做一个broadcast,为什么要做呢?因为做之前的shape为(batchsize, kv head num, max seq len, head size),我们想要的shape为(batchsize, q head num, max k len, head size),为什么想要这个呢?因为我们要将q和k做batch gemm,batch gemm对shape的要求是AB矩阵的batch肯定要一样,所以需要把kv head num广播为q head num。
但是这个开销也是可以减小的,在flash attention的实现里面,把这个广播操作融合了进去,那么会极大的减小memory traffic,对带宽的压力非常小,那么kv cache的size减小带来的收益将会更显著。
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# repeat k/v heads if n_kv_heads < n_heads # 扩展为与q一样的完整head
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
性能测试
秉着科研的严谨态度,最后贴几张性能测试结果的图片:
2.
最后,欢迎关注我的公众号“AI不止算法”