LoRA:参数高效微调的优雅之道
随着大语言模型(LLM)的规模日益增长,微调这些庞然大物以适应下游任务成为了一个计算资源与存储成本的双重挑战。传统的全参数微调(Full Fine-Tuning)虽然效果显著,但需要为每个任务保存一份完整的模型副本,这在实际部署中往往不切实际。为了解决这一问题,参数高效微调(Parameter-Efficient Fine-Tuning, PEFT)方法应运而生,其中LoRA(Low-Rank Adaptation)以其简洁的设计和优异的性能,成为近年来研究的热点。本文将深入探讨LoRA的原理、实现细节及其在LLM微调中的独特优势,面向熟悉深度学习和LLM的高级研究者。
一、背景与动机
在深度学习中,预训练模型(如BERT、LLaMA、GPT系列)通过大规模无监督数据学习通用表示,随后通过微调适配特定任务。然而,当模型参数量达到数十亿甚至千亿级别时,全参数微调的成本迅速攀升。以一个100亿参数的模型为例,若使用FP16存储,每份微调模型需要约20GB空间,多任务场景下的存储需求令人望而却步。此外,全参数微调可能导致灾难性遗忘(Catastrophic Forgetting),影响模型的泛化能力。
LoRA的提出(Hu et al., 2021)正是为了应对这些挑战。它基于一个核心假设:任务特定的更新可以用低秩矩阵近似表示,从而大幅减少需要训练和存储的参数量。这种方法不仅计算高效,还能在保持性能的同时实现任务间的快速切换。
二、LoRA的数学原理
LoRA的核心思想是对预训练权重矩阵的更新进行低秩分解。假设一个预训练的权重矩阵为 ( W 0 ∈ R d × k W_0 \in \mathbb{R}^{d \times k} W0∈Rd×k )(例如Transformer中的注意力或前馈层权重),传统微调会直接更新整个 ( W 0 W_0 W0 )。LoRA则假设权重的更新 ( Δ W \Delta W ΔW) 是低秩的,即:
Δ W = A B \Delta W = A B ΔW=AB
其中,( A ∈ R d × r A \in \mathbb{R}^{d \times r} A∈Rd×r ),( B ∈ R r × k B \in \mathbb{R}^{r \times k} B∈Rr×k ),且 ( r ≪ min ( d , k ) r \ll \min(d, k) r≪min(d,k))。最终的前向计算变为:
W = W 0 + Δ W = W 0 + A B W = W_0 + \Delta W = W_0 + A B W=W0+ΔW=W0+AB
- ( r r r ) 是LoRA的秩(rank),通常取值较小(如8、16或32),控制可训练参数的数量。
- 在训练时,( W 0 W_0 W0 ) 保持冻结,仅优化 ( A A A ) 和 ( B B B )。
这种设计的理论依据来源于矩阵分解的特性:任务特定的更新往往集中在低维子空间中,而无需修改整个权重矩阵。研究表明,即使 ( r r r ) 远小于原始矩阵的秩,LoRA也能捕捉任务的关键特征。
三、实现细节与实践经验
在实践中,LoRA通常应用于Transformer模型的特定层,例如自注意力模块的 ( W q W_q Wq )(查询)、( W k W_k Wk )(键)、( W v W_v Wv )(值)或前馈网络的权重矩阵。以下是一些关键实现要点:
- 选择目标层:并非所有权重都需要LoRA更新。实验表明,对注意力层的 ( W q W_q Wq ) 和 ( W v W_v Wv ) 应用LoRA通常已足够,而对 ( W k W_k Wk ) 和输出投影的干预可能收益有限。
- 初始化策略:( A A A ) 通常初始化为高斯分布,( B B B ) 初始化为零,确保初始时 ( Δ W = 0 \Delta W = 0 ΔW=0),避免对预训练权重造成干扰。
- 超参数调节:( r r r ) 和LoRA的缩放因子 ( α \alpha α)(用于调整 ( Δ W \Delta W ΔW) 的幅度,计算为 ( α / r \alpha / r α/r))对性能影响显著。研究者需根据任务复杂度平衡 ( r r r ) 的选择。
- 集成与推理:训练完成后,可将 ( A A A ) 和 ( B B B ) 与 ( W 0 W_0 W0 ) 合并,推理时无需额外计算开销。
开源工具如Hugging Face的PEFT库已提供了LoRA的便捷实现,支持与PyTorch无缝集成。例如,对LLaMA模型应用LoRA的代码片段如下:
from peft import LoraConfig, get_peft_model
config = LoraConfig(
r=16, # 秩
lora_alpha=32, # 缩放因子
target_modules=["q_proj", "v_proj"], # 目标层
lora_dropout=0.05 # dropout率
)
model = get_peft_model(base_model, config)
四、优势与局限性
优势
- 参数效率:相比全参数微调,LoRA的参数量减少了几个数量级。例如,一个100亿参数模型的全微调需要调整全部参数,而LoRA可能仅需训练0.01%的参数。
- 模块化设计:( A A A ) 和 ( B B B ) 可以独立存储和加载,支持多任务场景下的快速切换。
- 性能保留:大量实验表明,LoRA在下游任务上的性能与全微调接近,甚至在某些场景下更优。
局限性
- 表达能力受限:低秩假设可能无法完全捕捉复杂任务的高维更新需求。
- 超参数敏感性:( r r r ) 和 ( α \alpha α) 的选择依赖经验调整,缺乏理论指导。
- 适用范围:LoRA在Transformer架构中表现优异,但在其他架构(如CNN或RNN)上的效果尚未充分验证。
五、应用场景与研究进展
LoRA已被广泛应用于自然语言处理(NLP)任务,如文本分类、生成、翻译等。此外,它在多模态模型(如CLIP)和持续学习场景中也展现出潜力。近期研究进一步扩展了LoRA:
- AdaLoRA:动态调整秩以适应任务需求。
- LoRA+:结合其他PEFT方法(如Adapter)提升性能。
- 量化LoRA:与模型量化技术结合,进一步压缩存储需求。
六、未来方向
对于研究者而言,LoRA仍有多项值得探索的课题:
- 理论分析:低秩假设的适用边界是什么?能否从信息论或优化角度解释其成功?
- 自动化设计:开发自适应的秩选择策略,减少手动调参。
- 跨架构扩展:将LoRA推广至非Transformer模型,甚至硬件加速器。
- 与推理优化的协同:如何在边缘设备上结合LoRA实现高效部署?
七、结语
LoRA以其简洁而强大的设计,为LLM的高效微调提供了一条优雅的路径。它不仅降低了计算和存储的门槛,还为多任务学习和模型共享开辟了新可能。对于高级研究者而言,LoRA既是一个实用工具,也是一个值得深入挖掘的研究课题。希望本文能为你的实验和写作提供启发,欢迎在评论区交流你的LoRA实践经验!
Lora底层的实现
想要深入探讨 target_modules=["q_proj", "v_proj"]
在底层是如何实现的,特别是如何通过代码给Transformer的 (
W
q
W_q
Wq )(Query矩阵)和 (
W
v
W_v
Wv )(Value矩阵)添加LoRA的低秩更新矩阵 (
Δ
W
=
A
B
\Delta W = A B
ΔW=AB)。下面详细解析Hugging Face PEFT库中LoRA的底层实现细节,尤其是如何定位和修改这些目标模块。
一、LoRA底层实现的总体思路
在PEFT库中,LoraConfig
定义了LoRA的参数和目标模块(target_modules
),get_peft_model
函数则根据配置动态修改模型的权重层。底层实现的核心步骤包括:
- 定位目标模块:找到Transformer中对应的
"q_proj"
和"v_proj"
层(通常是线性层nn.Linear
)。 - 替换或增强权重:将原始的 ( W 0 W_0 W0 )(如 ( W q W_q Wq ) 或 ( W v W_v Wv ))替换为一个新的模块,支持 ( W 0 + A B W_0 + A B W0+AB ) 的计算。
- 前向计算:在新模块中实现带有LoRA更新的前向传播。
target_modules=["q_proj", "v_proj"]
的实现依赖于对模型结构的遍历和模块替换,具体由PEFT库的 LoraLayer
类和相关逻辑完成。
二、底层实现细节
1. 定位 target_modules
在Transformer模型(如LLaMA)中,"q_proj"
和 "v_proj"
是自注意力模块中线性层的命名约定。例如:
- 在
transformers.models.llama.modeling_llama.LlamaAttention
中:q_proj
是nn.Linear
,用于计算 ( Q = W q ⋅ X Q = W_q \cdot X Q=Wq⋅X )。v_proj
是nn.Linear
,用于计算 ( V = W v ⋅ X V = W_v \cdot X V=Wv⋅X )。
PEFT库通过递归遍历模型的层结构,匹配 target_modules
中的名称。代码大致逻辑如下(简化版):
def _find_and_replace_module(model, config):
for name, module in model.named_modules():
if any(target in name for target in config.target_modules):
# 找到目标模块(如 q_proj 或 v_proj 的 nn.Linear)
new_module = LoraLinear(module, config) # 用带LoRA的层替换
parent_name, child_name = name.rsplit(".", 1)
parent = model.get_submodule(parent_name)
setattr(parent, child_name, new_module)
# 示例调用
model = get_peft_model(base_model, lora_config)
named_modules()
返回模型中所有子模块的名称和实例。- 检查模块名称是否包含
"q_proj"
或"v_proj"
。 - 找到后,用一个新的
LoraLinear
层替换原始的nn.Linear
。
2. LoraLinear
的实现
LoraLinear
是PEFT库中实现LoRA的核心类。它继承自 torch.nn.Module
,并增强了原始线性层以支持低秩更新。以下是其简化实现(参考PEFT源码 peft.tuners.lora.layer.LoraLayer
和 peft.tuners.lora.layer.Linear
, 代码链接:https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py):
import torch
import torch.nn as nn
class LoraLinear(nn.Module):
def __init__(self, base_layer, config):
super().__init__()
self.base_layer = base_layer # 原始 nn.Linear 层
self.r = config.r # 秩
self.lora_alpha = config.lora_alpha # 缩放因子
self.lora_dropout = nn.Dropout(config.lora_dropout) # Dropout
self.scaling = self.lora_alpha / self.r # 缩放系数
# 获取原始权重维度
in_features, out_features = base_layer.weight.shape[1], base_layer.weight.shape[0]
# 初始化 LoRA 矩阵 A 和 B
self.lora_A = nn.Parameter(torch.randn(in_features, self.r)) # [d, r]
self.lora_B = nn.Parameter(torch.zeros(self.r, out_features)) # [r, k]
# 冻结原始权重
self.base_layer.weight.requires_grad = False
if hasattr(self.base_layer, "bias") and config.bias == "none":
self.base_layer.bias.requires_grad = False
def forward(self, x):
# 原始线性变换
base_output = self.base_layer(x) # W_0 * x
# LoRA 更新部分
lora_output = self.lora_dropout(x) # 应用 dropout
lora_output = torch.matmul(lora_output, self.lora_A) # x * A, [batch, r]
lora_output = torch.matmul(lora_output, self.lora_B) # (x * A) * B, [batch, out_features]
lora_output = lora_output * self.scaling # 应用缩放
# 合并输出
return base_output + lora_output
# 示例:替换 q_proj
base_layer = nn.Linear(4096, 4096) # 假设 LLaMA 的隐藏维度为 4096
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj"], lora_dropout=0.05)
lora_layer = LoraLinear(base_layer, lora_config)
关键点:
- 权重冻结:原始的 (
W
0
W_0
W0 )(
base_layer.weight
)被冻结,仅 ( A A A ) 和 ( B B B ) 可训练。 - 并行矩阵:(
Δ
W
=
A
B
\Delta W = A B
ΔW=AB) 不显式计算为完整矩阵,而是通过两次矩阵乘法(
x * A * B
)实现,避免存储 ( d × k d \times k d×k ) 的矩阵。 - 缩放:
scaling
调整LoRA更新的幅度,确保数值稳定性。
3. 前向传播的数学过程
假设输入 ( x ∈ R b a t c h × d x \in \mathbb{R}^{batch \times d} x∈Rbatch×d ):
- 原始输出:( y 0 = W 0 ⋅ x y_0 = W_0 \cdot x y0=W0⋅x),( W 0 ∈ R k × d W_0 \in \mathbb{R}^{k \times d} W0∈Rk×d )。
- LoRA更新:
- ( h = x ⋅ A h = x \cdot A h=x⋅A ),( A ∈ R d × r A \in \mathbb{R}^{d \times r} A∈Rd×r ),( h ∈ R b a t c h × r h \in \mathbb{R}^{batch \times r} h∈Rbatch×r );
- ( Δ y = h ⋅ B \Delta y = h \cdot B Δy=h⋅B ),( B ∈ R r × k B \in \mathbb{R}^{r \times k} B∈Rr×k ),( Δ y ∈ R b a t c h × k \Delta y \in \mathbb{R}^{batch \times k} Δy∈Rbatch×k );
- 缩放后:( Δ y = Δ y ⋅ ( α / r ) \Delta y = \Delta y \cdot (\alpha / r) Δy=Δy⋅(α/r) )。
- 最终输出:( y = y 0 + Δ y y = y_0 + \Delta y y=y0+Δy )。
这种分解避免了直接构造 ( Δ W \Delta W ΔW),计算复杂度从 ( O ( d k ) O(dk) O(dk) ) 降至 ( O ( d r + r k ) O(dr + rk) O(dr+rk) ),( r ≪ d , k r \ll d, k r≪d,k ) 时效率显著提升。
三、如何给 ( Q Q Q ) 和 ( V V V ) 矩阵加上“并行矩阵”?
这里的“并行矩阵”实际上是逻辑上的描述,底层并未显式构造 ( Δ W \Delta W ΔW):
- ( A A A ) 和 ( B B B ) 是两个独立的参数矩阵,分别与输入和输出维度绑定。
- 通过矩阵乘法的分解(
x * A * B
),LoRA动态生成更新量,等效于在 ( W q W_q Wq ) 和 ( W v W_v Wv ) 上叠加一个低秩调整。 - 在模型结构上,
LoraLinear
接管了原始线性层的计算逻辑,保留 ( W 0 W_0 W0 ) 的同时并行计算 ( Δ W \Delta W ΔW) 的贡献。
四、完整性验证
以下是验证代码的简单例子:
# 测试 LoraLinear
x = torch.randn(2, 4096) # batch_size=2, hidden_size=4096
output = lora_layer(x)
print(output.shape) # torch.Size([2, 4096])
# 检查可训练参数
for name, param in lora_layer.named_parameters():
if param.requires_grad:
print(f"{name}: {param.shape}")
# 输出:
# lora_A: torch.Size([4096, 16])
# lora_B: torch.Size([16, 4096])
可以看到,仅 ( A A A ) 和 ( B B B ) 是可训练的,参数量远小于原始的 ( 4096 × 4096 4096 \times 4096 4096×4096 )。
五、总结
target_modules=["q_proj", "v_proj"]
的底层实现依赖:
- 模块定位:通过名称匹配找到 ( W q W_q Wq ) 和 ( W v W_v Wv ) 的线性层。
- 层替换:用
LoraLinear
替换原始nn.Linear
,并嵌入 ( A A A ) 和 ( B B B )。 - 动态计算:前向传播中通过 ( x ⋅ A ⋅ B x \cdot A \cdot B x⋅A⋅B ) 实现低秩更新,而非显式构造并行矩阵。
这种设计既高效又灵活,是LoRA成功的关键。
低秩矩阵 ( A A A ) 和 ( B B B ) 的初始化策略
分析 LoRA 中低秩矩阵 ( A A A ) 和 ( B B B ) 的初始化策略(( A A A ) 使用随机高斯初始化,( B B B ) 初始化为零),其对训练稳定性与收敛速度的潜在影响,以及为何选择这样的设计,同时探讨其他可能的初始化方法(如基于 SVD 分解)。
一、当前初始化策略的分析
在 peft.tuners.lora.Linear.__init__
中,LoRA 的低秩矩阵初始化如下(简化版):
self.lora_A = nn.Parameter(torch.randn(in_features, r) * 0.02) # 高斯分布,标准差较小
self.lora_B = nn.Parameter(torch.zeros(r, out_features)) # 全零初始化
1. 为什么 ( A A A ) 是高斯分布,( B B B ) 是零初始化?
-
初始 ( Δ W = 0 \Delta W = 0 ΔW=0):
LoRA 的核心目标是保持预训练模型的原始行为不变,即在微调开始时,模型输出应尽量接近原始权重 ( W 0 W_0 W0 ) 的输出。通过将 ( B B B ) 初始化为零,确保 ( Δ W = A B = 0 \Delta W = A B = 0 ΔW=AB=0)(无论 ( A A A ) 的值如何),从而避免对 ( W 0 W_0 W0 ) 的立即干扰。这种“零起点”设计让微调从预训练状态平滑过渡,便于任务适配。 -
为什么 ( A A A ) 用高斯分布而非零:
如果 ( A A A ) 和 ( B B B ) 都初始化为零,则 ( Δ W \Delta W ΔW) 在整个训练过程中将始终为零(因为梯度更新依赖初始值,零初始会导致梯度为零)。为了让 LoRA 能够学习任务特定的更新,必须至少有一个矩阵(( A A A ) 或 ( B B B ))具有非零初始值。选择 ( A A A ) 为高斯分布而 ( B B B ) 为零,是一种折中:- ( A A A ) 的随机性为后续更新提供了多样性,确保模型能探索不同的低秩子空间。
- 高斯分布(通常标准差较小,如 0.02)避免初始值过大,防止训练早期的不稳定(如梯度爆炸)。
-
高斯分布的具体选择:
高斯分布是深度学习中常见的权重初始化方式(如 Xavier 或 He 初始化),因为它能提供适度的随机性,帮助网络跳出对称性陷阱。对于 ( A A A ),其维度是 ( [ i n f e a t u r e s , r ] [in_features, r] [infeatures,r] ),高斯初始化确保每一列(对应秩 ( r r r ) 的维度)有一定独立性,为后续优化提供良好的起点。
2. 对训练稳定性的影响
- 优点:
- 平滑过渡:( Δ W = 0 \Delta W = 0 ΔW=0) 保证了初始阶段模型行为与预训练一致,避免了因随机初始化导致的输出剧烈变化,增强了训练稳定性。
- 梯度控制:( B B B ) 为零,初始梯度仅通过 ( A A A ) 传播,配合小的标准差(如 0.02),降低了梯度爆炸的风险。
- 潜在问题:
- 初始更新缓慢:由于 ( B B B ) 为零,早期训练中 ( Δ W \Delta W ΔW) 的变化完全依赖 ( B B B ) 的梯度更新,而 ( B B B ) 的初始值为零可能导致收敛起步较慢,尤其在任务复杂度较高时。
- 依赖优化器:这种初始化对优化器的动量(如 Adam 的指数移动平均)依赖较大,若学习率设置不当,可能延迟有效收敛。
3. 对收敛速度的影响
- 优点:
- 避免过拟合早期干扰:初始零更新让模型逐步适应任务,避免过早偏离预训练知识,有助于长期收敛到更好的解。
- 潜在问题:
- 起步延迟:相比 ( A A A ) 和 ( B B B ) 都非零的初始化,当前策略可能需要更多步迭代才能让 ( Δ W \Delta W ΔW) 达到显著幅度,尤其当 ( r r r ) 较小时(低秩限制了表达能力)。
- 探索不足:( A A A ) 的高斯初始化虽然引入随机性,但标准差较小,可能限制模型在低秩空间中的初始探索范围。
二、其他初始化方法的探讨(如基于 SVD 分解)
1. 基于 SVD 的初始化
SVD(奇异值分解)可以将原始权重 ( $W_0 ) 分解为 ( W 0 = U Σ V T W_0 = U \Sigma V^T W0=UΣVT ),其中 ( U U U ) 和 ( V T V^T VT ) 是正交矩阵,( Σ \Sigma Σ) 是奇异值矩阵。可以用前 ( r r r ) 个奇异值和对应向量初始化 ( A A A ) 和 ( B B B ):
- ( A = U [ : , : r ] ⋅ Σ [ : r ] A = U[:, :r] \cdot \sqrt{\Sigma[:r]} A=U[:,:r]⋅Σ[:r] )
- ( B = Σ [ : r ] ⋅ V T [ : r , : ] B = \sqrt{\Sigma[:r]} \cdot V^T[:r, :] B=Σ[:r]⋅VT[:r,:] )
优点:
- 基于预训练知识:这种初始化直接从 ( W 0 W_0 W0 ) 的低秩近似开始,可能加速任务适配,尤其是当任务与预训练数据高度相关时。
- 表达能力:初始 ( Δ W \Delta W ΔW) 已接近 ( W 0 W_0 W0 ) 的主成分,可能比随机初始化更快捕捉关键特征。
缺点:
- 初始干扰:( Δ W ≠ 0 \Delta W \neq 0 ΔW=0),可能破坏预训练模型的稳定性。
- 计算开销:SVD 分解对大矩阵(如 ( 4096 × 4096 4096 \times 4096 4096×4096 ))计算成本高,不适合实时初始化。
- 任务无关性:预训练权重的低秩分解未必与下游任务相关,可能引入噪声。
2. 其他可能方法
- Xavier 初始化:对 ( A A A ) 和 ( B B B ) 都使用 Xavier 初始化(根据输入输出维度调整方差),可能加速收敛,但初始 ( Δ W ≠ 0 \Delta W \neq 0 ΔW=0) 会增加不稳定性。
- 正交初始化:对 (
A
A
A ) 或 (
B
B
B ) 使用正交矩阵初始化(如
torch.nn.init.orthogonal_
),可能改善梯度传播,但仍需解决初始零的问题。 - 动态初始化:根据任务数据的前几步梯度动态调整 ( A A A ) 和 ( B B B ) 的初始值,但这需要额外的预处理步骤。
三、改进建议与源码探讨
基于上述分析,如果你想在 PEFT 源码中优化初始化策略,可以尝试以下方向:
-
自适应标准差:
修改peft.tuners.lora.Linear.__init__
,让 ( A A A ) 的高斯标准差根据 ( i n _ f e a t u r e s in\_features in_features ) 或 ( r r r ) 自适应:std = 1.0 / math.sqrt(in_features) # 类似 He 初始化 self.lora_A = nn.Parameter(torch.randn(in_features, r) * std)
这可能提升初始探索能力。
-
部分非零 ( B B B ):
让 ( B B B ) 的部分元素(例如前几行)非零,加速早期收敛,同时控制初始扰动:self.lora_B = nn.Parameter(torch.zeros(r, out_features)) self.lora_B.data[:r//2, :] = torch.randn(r//2, out_features) * 0.01
-
SVD 预热:
在训练前对 ( W 0 W_0 W0 ) 做一次 SVD,缓存结果用于初始化(需扩展LoraConfig
支持预计算)。
四、总结
- 当前策略(( A A A ) 高斯,( B B B ) 零)通过初始零更新保证稳定性,但可能牺牲早期收敛速度。
- 高斯分布为 ( A A A ) 提供随机性,确保学习能力;( B B B ) 为零避免初始干扰。
- SVD 初始化理论上有潜力,但需解决计算成本和任务适配问题。
重新审视两种初始化场景
1. 当前默认初始化(( A A A ) 高斯,( B B B ) 零)
- ( A A A ):随机高斯初始化,例如 ( A i j ∼ N ( 0 , 0.0 2 2 ) A_{ij} \sim \mathcal{N}(0, 0.02^2) Aij∼N(0,0.022) )。
- ( B B B ):全零矩阵。
- 结果:( Δ W = A B = 0 \Delta W = A B = 0 ΔW=AB=0)。
- 影响:
- 初始时,模型输出完全由 ( W 0 W_0 W0 ) 决定,没有任何扰动。
- 训练开始后,( Δ W \Delta W ΔW) 的变化依赖 ( B B B ) 的梯度更新,逐步偏离零。
2. 反过来(( B B B ) 高斯,( A A A ) 零)
- ( B B B ):随机高斯初始化,例如 ( B i j ∼ N ( 0 , 0.0 2 2 ) B_{ij} \sim \mathcal{N}(0, 0.02^2) Bij∼N(0,0.022) )。
- ( A A A ):全零矩阵。
- 结果:( Δ W = A B = 0 \Delta W = A B = 0 ΔW=AB=0)。
- 影响:
- 初始时,( Δ W \Delta W ΔW) 仍然为零,模型输出依然由 ( W 0 W_0 W0 ) 决定,与默认初始化一致。
- 训练开始后,( Δ W \Delta W ΔW) 的变化依赖 ( A A A ) 的梯度更新。
两种初始化的区别
虽然两种初始化(( A A A ) 高斯 ( B B B ) 零 vs. ( B B B ) 高斯 ( A A A ) 零)在初始时都导致 ( Δ W = 0 \Delta W = 0 ΔW=0),但它们的区别体现在训练过程中的梯度传播和参数更新上:
1. 默认初始化(( A A A ) 高斯,( B B B ) 零)
- 前向传播:
y = W 0 x + ( A B ) x = W 0 x ( 因 B = 0 ) y = W_0 x + (A B) x = W_0 x \quad (\text{因 } B = 0) y=W0x+(AB)x=W0x(因 B=0) - 反向传播:
- 假设损失为 (
L
L
L ),对 (
B
B
B ) 的梯度:
∂ L ∂ B = A T ⋅ ∂ L ∂ y \frac{\partial L}{\partial B} = A^T \cdot \frac{\partial L}{\partial y} ∂B∂L=AT⋅∂y∂L
因为 ( A A A ) 非零,( B B B ) 会获得非零梯度,开始更新。 - 对 (
A
A
A ) 的梯度:
∂ L ∂ A = x T ⋅ ∂ L ∂ y ⋅ B T \frac{\partial L}{\partial A} = x^T \cdot \frac{\partial L}{\partial y} \cdot B^T ∂A∂L=xT⋅∂y∂L⋅BT
因为 ( B = 0 B = 0 B=0 ),初始时 ( A A A ) 的梯度为零,更新依赖 ( B B B ) 先动起来。
- 假设损失为 (
L
L
L ),对 (
B
B
B ) 的梯度:
- 动态:( B B B ) 先更新,带动 ( A A A ) 逐步调整,( Δ W \Delta W ΔW) 逐渐偏离零。
2. 反向初始化(( B B B ) 高斯,( A A A ) 零)
- 前向传播:
y = W 0 x + ( A B ) x = W 0 x ( 因 A = 0 ) y = W_0 x + (A B) x = W_0 x \quad (\text{因 } A = 0) y=W0x+(AB)x=W0x(因 A=0) - 反向传播:
- 对 (
A
A
A ) 的梯度:
∂ L ∂ A = x T ⋅ ∂ L ∂ y ⋅ B T \frac{\partial L}{\partial A} = x^T \cdot \frac{\partial L}{\partial y} \cdot B^T ∂A∂L=xT⋅∂y∂L⋅BT
因为 ( B B B) 非零,( A A A ) 会立即获得非零梯度,开始更新。 - 对 (
B
B
B ) 的梯度:
∂ L ∂ B = A T ⋅ ∂ L ∂ y \frac{\partial L}{\partial B} = A^T \cdot \frac{\partial L}{\partial y} ∂B∂L=AT⋅∂y∂L
因为 ( A = 0 A = 0 A=0 ),初始时 ( B B B ) 的梯度为零,更新依赖 ( A A A ) 先动起来。
- 对 (
A
A
A ) 的梯度:
- 动态:( A A A ) 先更新,带动 ( B B B ) 调整,( Δ W \Delta W ΔW) 逐渐偏离零。
关键区别:
- 更新顺序:
- ( A A A ) 高斯 ( B B B ) 零:( B B B ) 先获得梯度,驱动 ( Δ W \Delta W ΔW) 变化。
- ( B B B ) 高斯 ( A A A ) 零:( A A A ) 先获得梯度,驱动 ( Δ W \Delta W ΔW) 变化。
- 梯度传播路径:
- ( A A A ) 在输入侧(靠近 ( x x x )),其更新直接受输入分布影响。
- ( B B B ) 在输出侧(靠近 ( y y y )),其更新更直接受损失影响。
- 实际影响:
- ( A A A ) 高斯 ( B B B ) 零更常见,因为 ( A A A ) 的随机性为初始探索提供了基础,而 ( B B B ) 的零初始化确保稳定性。
- ( B B B ) 高斯 ( A A A ) 零可能让更新更依赖输出端的梯度信号,可能在某些任务中收敛路径略有不同,但初始稳定性无差别。
对A和B梯度的计算推导
一、前向传播的正确定义
LoRA 的前向传播为:
y
=
W
0
x
+
(
A
B
)
x
y = W_0 x + (A B) x
y=W0x+(AB)x
其中:
- ( x ∈ R b a t c h × d x \in \mathbb{R}^{batch \times d} x∈Rbatch×d ):输入张量(批量优先)。
- ( W 0 ∈ R k × d W_0 \in \mathbb{R}^{k \times d} W0∈Rk×d ):冻结的预训练权重。
- ( A ∈ R d × r A \in \mathbb{R}^{d \times r} A∈Rd×r ):LoRA 的第一个低秩矩阵。
- ( B ∈ R r × k B \in \mathbb{R}^{r \times k} B∈Rr×k ):LoRA 的第二个低秩矩阵。
- ( y ∈ R b a t c h × k y \in \mathbb{R}^{batch \times k} y∈Rbatch×k ):输出张量。
LoRA 部分的计算可以分解为:
y
lora
=
(
A
B
)
x
=
(
x
A
)
B
y_{\text{lora}} = (A B) x = (x A) B
ylora=(AB)x=(xA)B
- 先计算 (
h
=
x
A
h = x A
h=xA ):
- ( x ∈ R b a t c h × d x \in \mathbb{R}^{batch \times d} x∈Rbatch×d ),
- ( A ∈ R d × r A \in \mathbb{R}^{d \times r} A∈Rd×r ),
- ( h ∈ R b a t c h × r h \in \mathbb{R}^{batch \times r} h∈Rbatch×r )(维度匹配)。
- 再计算 (
y
lora
=
h
B
y_{\text{lora}} = h B
ylora=hB ):
- ( h ∈ R b a t c h × r h \in \mathbb{R}^{batch \times r} h∈Rbatch×r ),
- ( B ∈ R r × k B \in \mathbb{R}^{r \times k} B∈Rr×k ),
- ( y lora ∈ R b a t c h × k y_{\text{lora}} \in \mathbb{R}^{batch \times k} ylora∈Rbatch×k )(维度匹配)。
这与 PEFT 源码中的实现一致(lora_output = torch.matmul(torch.matmul(x, lora_A), lora_B)
)。
二、梯度推导
假设损失函数为标量 ( L L L ),依赖于 ( y y y )。我们需要计算 ( ∂ L ∂ A \frac{\partial L}{\partial A} ∂A∂L) 和 ( ∂ L ∂ B \frac{\partial L}{\partial B} ∂B∂L),已知 ( ∂ L ∂ y ∈ R b a t c h × k \frac{\partial L}{\partial y} \in \mathbb{R}^{batch \times k} ∂y∂L∈Rbatch×k) 是反向传播的上游梯度。
1. ( ∂ L ∂ A \frac{\partial L}{\partial A} ∂A∂L)
根据链式法则:
∂
L
∂
A
=
∂
L
∂
y
⋅
∂
y
∂
A
\frac{\partial L}{\partial A} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial A}
∂A∂L=∂y∂L⋅∂A∂y
(1) 计算 ( ∂ y ∂ A \frac{\partial y}{\partial A} ∂A∂y)
- ( y = W 0 x + h B y = W_0 x + h B y=W0x+hB ),
- ( h = x A h = x A h=xA )。
对 (
A
A
A ) 求偏导:
∂
y
∂
A
=
∂
(
h
B
)
∂
A
=
∂
h
∂
A
⋅
B
\frac{\partial y}{\partial A} = \frac{\partial (h B)}{\partial A} = \frac{\partial h}{\partial A} \cdot B
∂A∂y=∂A∂(hB)=∂A∂h⋅B
- ( h = x A h = x A h=xA ),
- 对 (
A
p
q
A_{pq}
Apq )(第 (
p
p
p ) 行,第 (
q
q
q ) 列)求偏导:
h i j = ∑ m = 1 d x i m A m j h_{ij} = \sum_{m=1}^{d} x_{im} A_{mj} hij=m=1∑dximAmj
∂ h i j ∂ A p q = { x i p if j = q 0 otherwise \frac{\partial h_{ij}}{\partial A_{pq}} = \begin{cases} x_{ip} & \text{if } j = q \\ 0 & \text{otherwise} \end{cases} ∂Apq∂hij={xip0if j=qotherwise
矩阵形式:
∂ h ∂ A = x T \frac{\partial h}{\partial A} = x^T ∂A∂h=xT
(这里 ( x T ∈ R d × b a t c h x^T \in \mathbb{R}^{d \times batch} xT∈Rd×batch ) 是输入的转置,实际是张量形式)。 - 然后:
y = h B y = h B y=hB
∂ y ∂ h = B \frac{\partial y}{\partial h} = B ∂h∂y=B
(( B ∈ R r × k B \in \mathbb{R}^{r \times k} B∈Rr×k ))。
综合:
∂
y
∂
A
=
∂
h
∂
A
⋅
B
=
x
T
⋅
B
\frac{\partial y}{\partial A} = \frac{\partial h}{\partial A} \cdot B = x^T \cdot B
∂A∂y=∂A∂h⋅B=xT⋅B
(2) 合并梯度
∂ L ∂ A = ∂ L ∂ y ⋅ ∂ y ∂ A \frac{\partial L}{\partial A} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial A} ∂A∂L=∂y∂L⋅∂A∂y
- ( ∂ L ∂ y ∈ R b a t c h × k \frac{\partial L}{\partial y} \in \mathbb{R}^{batch \times k} ∂y∂L∈Rbatch×k),
- ( x T ∈ R d × b a t c h x^T \in \mathbb{R}^{d \times batch} xT∈Rd×batch ),
- ( B ∈ R r × k B \in \mathbb{R}^{r \times k} B∈Rr×k )(但实际需要 ( B T ∈ R k × r B^T \in \mathbb{R}^{k \times r} BT∈Rk×r ) 与链式法则匹配)。
正确形式为:
∂
L
∂
A
=
x
T
⋅
∂
L
∂
y
⋅
B
T
\frac{\partial L}{\partial A} = x^T \cdot \frac{\partial L}{\partial y} \cdot B^T
∂A∂L=xT⋅∂y∂L⋅BT
- ( x T ⋅ ∂ L ∂ y ∈ R d × k x^T \cdot \frac{\partial L}{\partial y} \in \mathbb{R}^{d \times k} xT⋅∂y∂L∈Rd×k ),
- ( B T ∈ R k × r B^T \in \mathbb{R}^{k \times r} BT∈Rk×r ),
- ( ∂ L ∂ A ∈ R d × r \frac{\partial L}{\partial A} \in \mathbb{R}^{d \times r} ∂A∂L∈Rd×r)(与 ( A A A ) 形状一致)。
2. ( ∂ L ∂ B \frac{\partial L}{\partial B} ∂B∂L)
∂ L ∂ B = ∂ L ∂ y ⋅ ∂ y ∂ B \frac{\partial L}{\partial B} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial B} ∂B∂L=∂y∂L⋅∂B∂y
(1) 计算 ( ∂ y ∂ B \frac{\partial y}{\partial B} ∂B∂y)
- ( y = h B y = h B y=hB ),
- ( h = x A h = x A h=xA )。
对 (
B
p
q
B_{pq}
Bpq ) 求偏导:
y
i
j
=
∑
m
=
1
r
h
i
m
B
m
j
y_{ij} = \sum_{m=1}^{r} h_{im} B_{mj}
yij=m=1∑rhimBmj
∂
y
i
j
∂
B
p
q
=
{
h
i
p
if
j
=
q
0
otherwise
\frac{\partial y_{ij}}{\partial B_{pq}} = \begin{cases} h_{ip} & \text{if } j = q \\ 0 & \text{otherwise} \end{cases}
∂Bpq∂yij={hip0if j=qotherwise
矩阵形式:
∂
y
∂
B
=
h
T
\frac{\partial y}{\partial B} = h^T
∂B∂y=hT
((
h
T
∈
R
r
×
b
a
t
c
h
h^T \in \mathbb{R}^{r \times batch}
hT∈Rr×batch ))。
(2) 合并梯度
∂ L ∂ B = ∂ L ∂ y ⋅ ∂ y ∂ B \frac{\partial L}{\partial B} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial B} ∂B∂L=∂y∂L⋅∂B∂y
- ( ∂ L ∂ y ∈ R b a t c h × k \frac{\partial L}{\partial y} \in \mathbb{R}^{batch \times k} ∂y∂L∈Rbatch×k),
- ( h = x A ∈ R b a t c h × r h = x A \in \mathbb{R}^{batch \times r} h=xA∈Rbatch×r ),
- ( h T ∈ R r × b a t c h h^T \in \mathbb{R}^{r \times batch} hT∈Rr×batch )。
正确形式:
∂
L
∂
B
=
h
T
⋅
∂
L
∂
y
\frac{\partial L}{\partial B} = h^T \cdot \frac{\partial L}{\partial y}
∂B∂L=hT⋅∂y∂L
代入 (
h
=
x
A
h = x A
h=xA ):
∂
L
∂
B
=
(
x
A
)
T
⋅
∂
L
∂
y
=
A
T
x
T
⋅
∂
L
∂
y
\frac{\partial L}{\partial B} = (x A)^T \cdot \frac{\partial L}{\partial y} = A^T x^T \cdot \frac{\partial L}{\partial y}
∂B∂L=(xA)T⋅∂y∂L=ATxT⋅∂y∂L
但 PyTorch 中批量维度通常隐式处理,最终简化为:
∂
L
∂
B
=
A
T
⋅
∂
L
∂
y
\frac{\partial L}{\partial B} = A^T \cdot \frac{\partial L}{\partial y}
∂B∂L=AT⋅∂y∂L
- ( A T ∈ R r × d A^T \in \mathbb{R}^{r \times d} AT∈Rr×d ),
- ( ∂ L ∂ y ∈ R b a t c h × k \frac{\partial L}{\partial y} \in \mathbb{R}^{batch \times k} ∂y∂L∈Rbatch×k),
- ( ∂ L ∂ B ∈ R r × k \frac{\partial L}{\partial B} \in \mathbb{R}^{r \times k} ∂B∂L∈Rr×k)(与 ( B B B ) 形状一致)。
三、结合初始化场景分析
1. 默认初始化(( A A A ) 高斯,( B B B ) 零)
- 前向传播:
y = W 0 x + ( A B ) x = W 0 x ( 因 B = 0 ) y = W_0 x + (A B) x = W_0 x \quad (\text{因 } B = 0) y=W0x+(AB)x=W0x(因 B=0) - 反向传播:
- 对 (
A
A
A ) 的梯度:
∂ L ∂ A = x T ⋅ ∂ L ∂ y ⋅ B T \frac{\partial L}{\partial A} = x^T \cdot \frac{\partial L}{\partial y} \cdot B^T ∂A∂L=xT⋅∂y∂L⋅BT- ( B = 0 B = 0 B=0 ),( B T = 0 B^T = 0 BT=0 ),
- ( ∂ L ∂ A = x T ⋅ ∂ L ∂ y ⋅ 0 = 0 \frac{\partial L}{\partial A} = x^T \cdot \frac{\partial L}{\partial y} \cdot 0 = 0 ∂A∂L=xT⋅∂y∂L⋅0=0),
- 初始时 ( A A A ) 的梯度为零,不更新。
- 对 (
B
B
B ) 的梯度:
∂ L ∂ B = A T ⋅ ∂ L ∂ y \frac{\partial L}{\partial B} = A^T \cdot \frac{\partial L}{\partial y} ∂B∂L=AT⋅∂y∂L- ( A A A ) 非零(高斯初始化),
- ( ∂ L ∂ y \frac{\partial L}{\partial y} ∂y∂L) 通常非零(取决于损失),
- ( ∂ L ∂ B ≠ 0 \frac{\partial L}{\partial B} \neq 0 ∂B∂L=0),( B B B ) 开始更新。
- 对 (
A
A
A ) 的梯度:
- 动态:
- ( B B B ) 先获得非零梯度并更新,脱离全零状态。
- ( B B B ) 更新后,( ∂ L ∂ A \frac{\partial L}{\partial A} ∂A∂L) 变得非零,带动 ( A A A ) 调整。
- ( Δ W = A B \Delta W = A B ΔW=AB) 逐渐偏离零。
2. 反向初始化(( B B B ) 高斯,( A A A ) 零)
- 前向传播:
y = W 0 x + ( A B ) x = W 0 x ( 因 A = 0 ) y = W_0 x + (A B) x = W_0 x \quad (\text{因 } A = 0) y=W0x+(AB)x=W0x(因 A=0) - 反向传播:
- 对 (
A
A
A ) 的梯度:
∂ L ∂ A = x T ⋅ ∂ L ∂ y ⋅ B T \frac{\partial L}{\partial A} = x^T \cdot \frac{\partial L}{\partial y} \cdot B^T ∂A∂L=xT⋅∂y∂L⋅BT- ( B B B ) 非零(高斯初始化),
- ( ∂ L ∂ y \frac{\partial L}{\partial y} ∂y∂L) 通常非零,
- ( ∂ L ∂ A ≠ 0 \frac{\partial L}{\partial A} \neq 0 ∂A∂L=0),( A A A ) 开始更新。
- 对 (
B
B
B ) 的梯度:
∂ L ∂ B = A T ⋅ ∂ L ∂ y \frac{\partial L}{\partial B} = A^T \cdot \frac{\partial L}{\partial y} ∂B∂L=AT⋅∂y∂L- ( A = 0 A = 0 A=0 ),( A T = 0 A^T = 0 AT=0 ),
- ( ∂ L ∂ B = 0 ⋅ ∂ L ∂ y = 0 \frac{\partial L}{\partial B} = 0 \cdot \frac{\partial L}{\partial y} = 0 ∂B∂L=0⋅∂y∂L=0),
- 初始时 ( B B B ) 的梯度为零,不更新。
- 对 (
A
A
A ) 的梯度:
- 动态:
- ( A A A ) 先获得非零梯度并更新,脱离全零状态。
- ( A A A ) 更新后,( ∂ L ∂ B \frac{\partial L}{\partial B} ∂B∂L) 变得非零,带动 ( B B B ) 调整。
- ( Δ W = A B \Delta W = A B ΔW=AB) 逐渐偏离零。
四、公式来源总结
- (
∂
L
∂
A
=
x
T
⋅
∂
L
∂
y
⋅
B
T
\frac{\partial L}{\partial A} = x^T \cdot \frac{\partial L}{\partial y} \cdot B^T
∂A∂L=xT⋅∂y∂L⋅BT):
- 从 (
y
=
(
x
A
)
B
y = (x A) B
y=(xA)B ) 推导:
- ( h = x A h = x A h=xA ) 对 ( A A A ) 的导数涉及 ( x x x ),
- ( y = h B y = h B y=hB ) 对 ( h h h ) 的导数涉及 ( B B B ),
- 链式法则合并得到此形式。
- 从 (
y
=
(
x
A
)
B
y = (x A) B
y=(xA)B ) 推导:
- (
∂
L
∂
B
=
A
T
⋅
∂
L
∂
y
\frac{\partial L}{\partial B} = A^T \cdot \frac{\partial L}{\partial y}
∂B∂L=AT⋅∂y∂L):
- (
y
=
h
B
y = h B
y=hB ) 对 (
B
B
B ) 的导数涉及 (
h
=
x
A
h = x A
h=xA ),
- 代入 ( h h h ) 后简化得到。
- (
y
=
h
B
y = h B
y=hB ) 对 (
B
B
B ) 的导数涉及 (
h
=
x
A
h = x A
h=xA ),
这些公式是矩阵微分和链式法则的直接结果,与 PEFT 源码中的自动求导一致。
五、验证与代码示例
以下是 PyTorch 验证代码:
import torch
x = torch.randn(2, 4) # batch=2, d=4
A = torch.zeros(4, 3) # d=4, r=3
B = torch.randn(3, 5) # r=3, k=5
A.requires_grad = True
B.requires_grad = True
y = torch.matmul(torch.matmul(x, A), B) # (x A) B
loss = y.sum() # 简单损失函数
loss.backward()
print("dL/dA:", A.grad) # 非零,因为 B 非零
print("dL/dB:", B.grad) # 零,因为 A = 0
scaling = lora_alpha / r
解释
LoRA 中 lora_alpha
和 scaling = lora_alpha / r
的作用与优化。我们会从其定义和作用入手,分析常见数值选择的意义,并探讨自适应缩放机制的可能性。
一、lora_alpha
和 scaling
的作用
在 LoRA 的实现中(见 peft.tuners.lora.LoraLinear.forward
),scaling
用于调整低秩更新 (\Delta W = A B) 的幅度。具体代码如下:
class LoraLinear(nn.Module):
def __init__(self, base_layer, config):
self.lora_alpha = config.lora_alpha # 超参数,例如 32
self.r = config.r # 秩,例如 16
self.scaling = self.lora_alpha / self.r # scaling = 32 / 16 = 2
self.lora_A = nn.Parameter(torch.randn(in_features, r))
self.lora_B = nn.Parameter(torch.zeros(r, out_features))
def forward(self, x):
base_output = self.base_layer(x)
lora_output = torch.matmul(torch.matmul(x, self.lora_A), self.lora_B)
return base_output + lora_output * self.scaling
-
定义:
lora_alpha
是一个超参数,通常由用户指定(常见值如 16、32)。scaling = lora_alpha / r
是实际应用于 (\Delta W x) 的缩放因子。- 前向传播中,输出为 ( y = W 0 x + ( A B ) x ⋅ scaling y = W_0 x + (A B) x \cdot \text{scaling} y=W0x+(AB)x⋅scaling )。
-
作用:
- 幅度调整:(
Δ
W
=
A
B
\Delta W = A B
ΔW=AB) 是低秩矩阵,其数值范围受 (
A
A
A ) 和 (
B
B
B ) 初始化及训练的影响。
scaling
确保 ( Δ W \Delta W ΔW) 的贡献与原始权重 ( W 0 W_0 W0 ) 的幅度匹配,避免更新过小(影响微调效果)或过大(破坏预训练知识)。 - 秩的归一化:(
r
r
r )(秩)决定了 (
Δ
W
\Delta W
ΔW) 的表达能力,(
r
r
r ) 越小,更新矩阵的自由度越低。
scaling = lora_alpha / r
通过除以 ( r r r ) 抵消秩的影响,使不同 ( r r r ) 的模型具有相似的更新幅度。 - 超参数解耦:将幅度控制从 ( A A A ) 和 ( B B B ) 的初始化中分离出来,便于手动调参。
- 幅度调整:(
Δ
W
=
A
B
\Delta W = A B
ΔW=AB) 是低秩矩阵,其数值范围受 (
A
A
A ) 和 (
B
B
B ) 初始化及训练的影响。
二、常见数值选择(16 或 32)的意义
1. 为什么是 16 或 32?
- 经验值:在 LoRA 的原始论文(Hu et al., 2021)和后续实践中,
lora_alpha
的常见取值(如 16、32)是通过实验确定的。这些值在多种任务和模型(如 Transformer、LLaMA)上表现良好。 - 与 (
r
r
r ) 的关系:
- ( r r r ) 通常取较小的值(例如 8、16、32),表示低秩分解的维度。
scaling = lora_alpha / r
通常落在 1~4 之间(例如 ( 16 / 8 = 2 16/8 = 2 16/8=2 ), ( 32 / 16 = 2 32/16 = 2 32/16=2 )),这是一个适中的缩放范围。- 如果 (
r
=
16
r = 16
r=16 ),
lora_alpha = 32
得到scaling = 2
,意味着 ( Δ W x \Delta W x ΔWx) 的贡献是原始计算的两倍;如果 ( r = 8 r = 8 r=8 ),则scaling = 4
。
- 数值稳定性:这些值避免了过大的缩放(可能导致梯度爆炸)或过小的缩放(削弱微调效果),与深度学习中常见的权重初始化(如标准差 0.02~0.1)相匹配。
2. 看法
- 优点:
- 固定值(如 16 或 32)简单易用,减少了调参负担,尤其对于初学者或标准化流程。
- 与 ( r r r ) 的比值设计考虑了秩的限制,确保不同配置下的一致性。
- 局限性:
- 缺乏理论依据:为什么是 32 而不是 30 或 40?这些值的选择更多是经验性的,缺乏明确的数学推导或任务自适应的依据。
- 任务依赖性:不同任务(例如分类 vs. 生成)或模型规模(7B vs. 70B 参数)可能需要不同的缩放幅度,固定值可能不总是最优。
- 初始化敏感性:
scaling
的效果与 ( A A A ) 和 ( B B B ) 的初始化(如高斯分布的标准差)强相关,若初始化不当,可能需要调整lora_alpha
来补偿。
三、自适应缩放机制的可能性
固定 lora_alpha
的设计虽然实用,但确实存在优化空间。自适应缩放可以根据模型状态、任务需求或训练过程动态调整幅度。以下是几种可能的思路:
1. 基于权重范数的自适应缩放
- 思路:
- 计算预训练权重 ( W 0 W_0 W0 ) 的范数(如 Frobenius 范数 ( ∣ ∣ W 0 ∣ ∣ F ||W_0||_F ∣∣W0∣∣F ))和 ( Δ W = A B \Delta W = A B ΔW=AB) 的初始范数。
- 动态调整
scaling
,使 ( Δ W \Delta W ΔW) 的贡献与 ( W 0 W_0 W0 ) 的幅度成比例:
scaling = α ⋅ ∣ ∣ W 0 ∣ ∣ F ∣ ∣ A B ∣ ∣ F \text{scaling} = \alpha \cdot \frac{||W_0||_F}{||A B||_F} scaling=α⋅∣∣AB∣∣F∣∣W0∣∣F
其中 ( α \alpha α) 是一个小的常数(如 0.1 或 1)。
- 优点:
- 自动适配不同层的权重规模(例如注意力层 vs. 前馈层)。
- 避免手动选择
lora_alpha
。
- 挑战:
- ( ∣ ∣ A B ∣ ∣ F ||A B||_F ∣∣AB∣∣F ) 在训练中变化,需实时计算或定期更新,增加计算开销。
- 初始 ( B = 0 B = 0 B=0 ) 时范数为零,需特殊处理。
2. 基于梯度统计的自适应缩放
- 思路:
- 在训练初期,统计 ( ∂ L ∂ y \frac{\partial L}{\partial y} ∂y∂L)(损失对输出的梯度)和 ( ∂ L ∂ ( A B ) \frac{\partial L}{\partial (A B)} ∂(AB)∂L) 的梯度范数。
- 调整
scaling
,使 LoRA 部分的梯度与 ( W 0 W_0 W0 ) 的梯度幅度匹配:
scaling = β ⋅ ∣ ∣ ∇ W 0 L ∣ ∣ ∣ ∣ ∇ ( A B ) L ∣ ∣ \text{scaling} = \beta \cdot \frac{||\nabla_{W_0} L||}{||\nabla_{(A B)} L||} scaling=β⋅∣∣∇(AB)L∣∣∣∣∇W0L∣∣
(( β \beta β) 为超参数)。
- 优点:
- 反映任务的实际需求,动态适应梯度分布。
- 挑战:
- 需要额外的梯度计算,可能影响训练效率。
- 早期梯度可能不稳定,需平滑处理(如指数移动平均)。
3. 基于秩的自适应调整
- 思路:
- 当前
scaling = lora_alpha / r
已考虑秩,但可以进一步优化。例如,使用非线性函数:
scaling = l o r a a l p h a r 或 l o r a a l p h a log ( r + 1 ) \text{scaling} = \frac{lora_alpha}{\sqrt{r}} \quad \text{或} \quad \frac{lora_alpha}{\log(r + 1)} scaling=rloraalpha或log(r+1)loraalpha - 或者根据 (
r
r
r ) 相对于 (
d
d
d ) 和 (
k
k
k ) 的比例动态调整:
scaling = l o r a a l p h a ⋅ min ( d , k ) r \text{scaling} = lora_alpha \cdot \frac{\min(d, k)}{r} scaling=loraalpha⋅rmin(d,k)
- 当前
- 优点:
- 更好地平衡秩的限制与更新幅度。
- 无需额外计算,简单可行。
- 挑战:
- 仍需手动设置
lora_alpha
,未完全自适应。
- 仍需手动设置
4. 学习缩放因子
- 思路:
- 将
scaling
视为可训练参数,而不是固定值。在LoraLinear
中定义:self.scaling = nn.Parameter(torch.tensor(1.0))
- 通过梯度下降优化
scaling
,使其自动适配任务。
- 将
- 优点:
- 完全自适应,无需手动调参。
- 与 ( A A A ) 和 ( B B B ) 联合优化,可能找到全局最优解。
- 挑战:
- 增加了一个可训练参数,可能放大训练不稳定性。
- 需要合适的初始化和正则化(例如限制
scaling
的范围)。
四、实现示例:自适应缩放
以下是一个基于梯度统计的自适应缩放的简单实现:
class AdaptiveLoraLinear(nn.Module):
def __init__(self, base_layer, config):
super().__init__()
self.base_layer = base_layer
self.r = config.r
self.lora_A = nn.Parameter(torch.randn(base_layer.in_features, r))
self.lora_B = nn.Parameter(torch.zeros(r, base_layer.out_features))
self.scaling = 1.0 # 初始值
self.register_buffer("ema_grad", torch.tensor(0.0)) # 梯度统计
def update_scaling(self, grad_output):
lora_grad = torch.matmul(torch.matmul(grad_output, self.lora_B.T), self.lora_A.T)
grad_norm = lora_grad.norm()
self.ema_grad = 0.9 * self.ema_grad + 0.1 * grad_norm # 平滑更新
self.scaling = 1.0 / (self.ema_grad + 1e-8) # 动态调整
def forward(self, x):
base_output = self.base_layer(x)
lora_output = torch.matmul(torch.matmul(x, self.lora_A), self.lora_B)
return base_output + lora_output * self.scaling
- 逻辑:根据 LoRA 部分的梯度范数动态调整
scaling
,使其与输出梯度反比。 - 效果:避免手动设置
lora_alpha
,但需在训练中调用update_scaling
。
五、总结与看法
- 数值选择(16 或 32):
- 这些值是经验调参的结果,简单实用,但在理论上缺乏严谨依据。认为它们适用于大多数场景,但对极端任务(如超大数据集或超小模型)可能需要调整。
- 自适应缩放:
- 更倾向于基于梯度或权重视的自适应机制,因为它们能反映模型和任务的实际需求。
- 学习
scaling
是一个有趣的方向,但需要仔细设计正则化以避免过拟合。
参考
https://magazine.sebastianraschka.com/p/practical-tips-for-finetuning-llms
后记
2025年3月11日22点22分于上海,在Grok 3大模型辅助下完成。