📖标题:Position-Aware Depth Decay Decoding (D3): Boosting Large Language Model Inference Efficiency
🌐来源:arXiv, 2503.08524
🌟摘要
🔸由于参数数量庞大,大型语言模型(LLM)的推理阶段是资源密集型的。与需要重新训练的传统模型压缩不同,最近的动态计算方法表明,推理并不需要所有组件,从而实现了无需训练的流水线。
🔸本文主要研究LLM生成的动态深度。提出了一种令牌位置感知跳层框架,在保持性能的同时,有效地节省了1.5倍的操作。我们首先观察到,稍后预测的令牌具有较低的困惑度,因此需要较少的计算。然后,我们提出了一种名为位置感知深度衰减解码(D3)的无训练算法,该算法利用幂律衰减函数L×(αi)来确定生成令牌Ti时要保留的层数。
🔸值得注意的是,在没有任何再训练的情况下,D3首次在广泛的生成任务中取得了成功。在具有70亿至700亿个参数的大型语言模型(即Llama)上的实验表明,与完整的推理流水线相比,D3可以实现平均1.5倍的加速,同时保持可比的性能,在GSM8K和BBH基准上几乎没有性能下降(<1%)。
🛎️文章简介
🔸研究问题:大语言模型(LLM)在生成任务中计算效率低下,是否可以动态减少每个token激活层数?
🔸主要贡献:论文提出了一种名为D3(Position-Aware Depth Decay Decoding)的框架,通过动态调整生成每个token时所需的计算资源,显著提高了推理速度,同时几乎不影响模型性能。
📝重点思路
🔸通过分析Llama模型中每个token的困惑度(PPL)变化,设计了一个幂律衰减函数,用于决定在生成token时需要保留的层数。
🔸D3框架不需要对模型进行再训练,只需通过网格搜索确定两个超参数(灵活层起始ID和衰减率α),便可在不同任务上实现自适应推理。
🔸解决了早期退出方法在批处理和KV缓存方面的挑战,提出了一种有效的层跳过策略,以提升生成任务中的计算效率。
🔎分析总结
🔸实验结果表明,D3在GSM8K和BBH基准测试上实现了平均1.5倍的推理速度提升,同时保持性能几乎没有下降(小于1%)。
🔸通过分析生成过程中不同token的计算需求,发现后期生成的token的困惑度较低,因此可以分配更少的计算资源,从而提高计算效率。
🔸在不同任务中,D3的超参数具有任务特定性,表明可以根据不同的任务需求进行调整以优化性能。
💡个人观点
论文的核心是动态调整每个token生成时需要用到的层,显著提升了推理速度并保持性能。