引言
今天介绍LLAMA2模型引入的关于注意力的改进——分组查询注意力(Grouped-query attention,GQA)。
Transformer中的多头注意力在解码阶段来说是一个性能瓶颈。多查询注意力通过共享单个key和value头,同时不减少query头来提升性能。多查询注意力可能导致质量下降和训练不稳定,因此常用的是分组查询注意力。
然后我们结合上篇文章探讨的旋转位置编码,将选择位置编码应用到分组查询注意力上。
多头注意力
我们先回顾以下原始多头注意力的实现。
import torch
from torch import nn