摘要
提示是当前利用语言模型(LM)的多任务能力的主要方法,但提示占据了输入上下文窗口中宝贵的空间,并且解码时重新编码相同的提示会导致计算效率低下。微调和蒸馏的方法允许在不进行提示的情况下对LM进行专业化,但需要针对每个任务进行重新训练。为了完全避免这种问题,我们提出了gisting,该方法训练一个LM来将提示压缩成较小的“要点”token集合,可以重用以提高计算效率。可以通过受约束的注意力屏蔽来作为指令微调的一部分,以轻松训练Gist模型,从而鼓励对提示进行压缩。在解码器(LLaMA-7B)和 encoder-decoder (FLAN-T5-XXL) LM上,gisting最多可将提示压缩26倍,最多可减少40%的FLOPS,4.2%的wall time加速,节省存储,并使输出质量受到的影响最小。
1.介绍
考虑一个诸如ChatGPT之类的Transformer语言模型(LM)的提示:
ChatGPT每天有数百万用户进行询问,该提示会被自注意力机制一遍一遍的编码,其时间和空间复杂度是输入长度的二次方。缓存提示的transformer激活可以防止某些重新计算,但是随着缓存提示的数量的增加,该策略仍然会增加内存和存储成本。在大模型上,随着时间的推移,即使提示长度的少量减少也可能导致大量的计算,内存和存储的节省,同时还可以让用户将更多的内容输出到LM有限的上下文窗口中。
我们如何降低提示的成本?一种典型的方法是对模型进行微调或蒸馏,以与原始模型性能相似,同时没有提示,也许这是一种参数高效的适应方法。然而,这种方法的基本缺点是,它需要为每个新提示重新训练模型(Figure 1, middle)。
取而代之的是,我们提出了gisting(图1,底部),其将任意提示压缩成较小的虚拟token,这是一种前缀微调的方式。但是,前缀微调需要通过梯度下降来学习每个任务的前缀,gisting则采用元学习方法,仅给定zero-shot提示来预测gist前缀。这可以减少每个任务的前缀成本,从而能够在没有任何训练的情况泛化导未知指令。由于gist token比完整的提示要短得多,因此gisting可以对任意提示进行压缩,缓存和重用以提高计算效率。
在本文中,我们进一步提出了一种非常简单的方法来学习一个针对指令的gist模型,即:在提示后插入gist token,然后进行指令微调,同时使用修改后的注意力屏蔽机制,以防止在gist token之后的token看到gist token之前的token。这使模型可以同时学习提示的压缩和指令,而没有额外的训练费用。
在decoder-only(LLaMA-7B)和encoder-decoder(Flan-T5-XXL)LM上,gisting最多可将提示压缩26倍,最多可减少40%的FLOPS,4.2%的wall time加速,节省存储,并使输出质量受到的影响最小。
2.Gisting
我们首先在指令微调的背景下对gisting进行描述。我们有一个遵循指令的数据集
D
=
{
(
t
i
,
x
i
,
y
i
)
}
i
=
1
N
\mathcal D=\{(t_i,x_i,y_i)\}^N_{i=1}
D={(ti,xi,yi)}i=1N,其中
t
t
t是一个由自然语言提示编码的任务(例如,Translate this to French),
x
x
x是任务的(可选)输入(例如,The cat),
y
y
y是所需的输出(例如,Le chat)。给定一个(通常是预训练)LM,指令微调的目的是学习一个分布
p
L
M
(
y
∣
t
,
x
)
p_{LM}(y|t,x)
pLM(y∣t,x),通常是通过拼接
t
t
t和
x
x
x,然后让LM自回归预测
y
y
y。在推理时,我们可以通过将新的任务
t
t
t和输入
x
x
x带入模型,通过模型解码以获得其预测。
但是,这种拼接
t
t
t和
x
x
x的模式具有缺陷:基于Transformer的LM具有有限的上下文窗口,该窗口受网络结构或计算资源的限制。尤其后者使用自注意力所产生的问题。因此,长序列提示
t
t
t,尤其是那些重复使用的提示,在计算上是效率低下的。我们可以用哪些方式来降低此提示的成本?
一种简单选择是针对每一个特定任务
t
t
t来微调LM。也就是说,给定的
D
t
=
{
(
x
i
,
y
i
)
}
i
=
1
N
t
\mathcal D^t=\{(x_i,y_i)\}^{N^t}_{i=1}
Dt={(xi,yi)}i=1Nt,该数据集仅包含任务
t
t
t相关的输入/输出样例,我们可以学习一个专业的LM
p
L
M
t
(
y
∣
x
)
p^t_{LM}(y|x)
pLMt(y∣x),因为它不包含
t
t
t,因此会更快。一些更好的方法是参数高效的微调,例如prefix-/prompt-tuning或adapters,这能减少全参数微调的成本。然而,问题仍然存在,即我们必须为每个任务存储一个模型权重的子集,更重要的是,对于每个任务
t
t
t,我们必须为每个任务收集相应的输入/输出数据集
D
t
\mathcal D^t
Dt并重新训练模型。
gisting是一种不同的方法,可以同时解决(1)以
t
t
t为条件推理
p
L
M
p_{LM}
pLM的成本和(2)为每个
t
t
t学习一个新
p
L
M
t
p^t_{LM}
pLMt的训练成本。这个想法是要在微调期间学习
t
t
t的压缩版本:
G
(
t
)
G(t)
G(t),以使得基于
p
G
(
y
∣
G
(
t
)
,
x
)
p_G(y|G(t),x)
pG(y∣G(t),x)的推理速度要比
p
L
M
(
y
∣
t
,
x
)
p_{LM}(y|t,x)
pLM(y∣t,x)快。 在LM种,
G
(
t
)
G(t)
G(t)是一组“虚拟”gist token,其长度要比
t
t
t小,但在LM种仍然具有相似行为。然后可以将
G
(
t
)
G(t)
G(t)表示激活并进行缓存(例如,键值矩阵),同时重复使用以提高计算效率。至关重要的是,我们希望
G
G
G能泛化到未知任务:给定一个新的任务
t
t
t,我们可以预测并使用相应的gist
G
(
t
)
G(t)
G(t)而无需任何其他训练。