优化 | 块坐标下降法:助力单张3090全参数高效微调7B级大模型

研究背景

随着大模型在人工智能领域的崛起,其强大的功能在各个研究领域得到了广泛的挖掘和应用。大模型的微调训练(fine-tuning)是实现其在下游任务中发挥作用的关键步骤,因此,针对高效微调训练的优化算法研究,已经成为了学术界和工业界关注的焦点。全参数微调能够最大限度地发掘大模型在特定任务上的潜力,但这种方法往往需要耗费大量的GPU计算资源 (GPU RAM)。在资源受限的情况下,诸如LoRA等参数高效的微调算法显得尤为重要,成为了在计算资源受限的环境下的首选方案,但其与全参数Adam微调仍存在一定的性能差异。如何在有限的资源下实现接近全参数微调的性能,已成为大模型研究领域的热点。本文从优化算法设计的视角出发,针对此问题提出了算法——BAdam(Block coordinate method with Adam as an inner solver),在大模型的微调训练中实现资源与性能的最优平衡。

算法设计

块坐标优化(block coordinate optimization)是一种历史悠久、变体众多的优化算法设计策略。在每次迭代中,这种优化策略保持大部分优化参数在其最新的迭代值,(近似)求解剩余参数形成的低维度优化问题。由于算法每步迭代需要求解的是一个比原始问题维度低得多的优化问题,应用高效的近似求解算法于子问题可最终获得原始大规模优化问题的高效求解算法。块坐标类优化算法尤其适用于优化变量数巨大的大规模优化问题,而这一特性正是大模型微调训练的特征,以Llama 2-7B大模型为例,其微调训练所需训练集中的数据个数通常在10万以下的量级,而其待优化的参数量却高达70亿。

由于上述子问题依旧具有高度非凸的特性,BAdam应用神经网络训练中被广泛认可的Adam算法作为子问题的近似求解器。算法的总体设计如下图:

算法特性

实验效果

本文所有实验均在单张RTX3090-24GB GPU上实际实现。 本文通过实际的微调任务场景,在Alpaca-GPT4数据集上微调训练Llama 2-7B,来比较BAdam与目前主流的几个内存高效的微调算法的性能表现。下图展示了相同data pass下几种内存高效微调算法的训练损失,可以看出BAdam算法的优势;根据上一节的分析,BAdam算法在实际运行时间上会有更明显的优势:

通过MT-bench评估的下游任务表现显示了BAdam算法在使用更少计算时间的同时,利用全参微调带来的相比LoRA微调算法的优势:

此外,在SuperGLUE benchmark上的表现显示出BAdam具有接近全参数Adam微调的能力:

总结

本研究初步探索表明,块坐标下降类算法在当代大模型研究领域展现出较为广泛的应用潜力。该类算法在确保下游任务性能不受明显影响的同时,有效降低了对GPU内存资源的依赖,进而促进了大模型在低内存资源条件下的高效优化。

更多详细内容

Qijun Luo, Hengxu Yu, Xiao Li. "BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models", arXiv preprint, 2024, https://arxiv.org/abs/2404.02827.

Github project page (code): https://github.com/Ledzy/BAdam.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值