LoRA构建:利用数学知识进行低阶自适应分析并在 PyTorch 中实现

公众号:Halo 咯咯

本文中将介绍了解 LoRA 是什么,并用数学原理知识来描述 LoRA 有效微调大型模型,最后从头开始创建我们自己的 LoRA 并使用它来微调我们的模型。

LoRA是如何工作的?

LLM(Large Language Models,大型语言模型)和其他类似的先进模型,例如稳定扩散模型,通常拥有数十亿个参数,这使得它们在处理复杂的人工智能任务时表现出色。然而,这种规模的模型需要庞大的预算和计算资源才能进行有效的微调,以适应特定的业务场景。
为了解决这一挑战,微软在其研究论文《LoRA: Low Rank Adaptation of LLMs》中提出了一种创新的方法——LoRA。LoRA的核心思想是优化微调过程,减少对计算资源的需求。
在传统的微调过程中,通常需要将整个模型加载至GPU,并执行反向传播算法来更新模型的所有权重。但LoRA采取了一种不同的路径。它通过冻结原始模型的初始权重W,并引入两个额外的低秩矩阵A和B来实现微调。这两个矩阵的乘积将生成一个新的权重矩阵,其维度与原始权重矩阵W相同。
在训练过程中,只有LoRA矩阵参与反向传播,而原始模型的权重保持不变。这样,LoRA大幅减少了在微调过程中需要更新的参数数量,从而降低了对计算资源的需求。这种方法不仅提高了效率,还使得在有限的资源下对大型模型进行微调成为可能,为各类企业打开了利用先进AI技术的大门。

正如之前提到的,LoRA(Low Rank Adaptation)的核心技术机制是基于矩阵分解的。这种方法允许我们在训练过程中以一种高效的方式来调整和优化模型的权重。

LoRA的数学原理

在LoRA的前向传递过程中,我们使用以下公式来计算隐藏层h:

h = W0 + ΔW

其中,h代表的是模型的隐藏层表示,W0是预训练模型中冻结的原始权重矩阵,其形状为(d x k),这里的d表示特征的维度,k表示模型的容量或者隐藏单元的数量。ΔW则是LoRA技术所特有的权重矩阵,它通过低秩矩阵乘法来近似原始权重矩阵W0的更新。

LoRA通过引入两个新的矩阵B和A来实现这种权重的近似更新。这两个矩阵的维度分别是(d x r)和(r x k),其中r是低秩矩阵的秩,通常远小于原始权重矩阵W0的维度。通过这种方式,B和A的乘积将产生一个新的权重矩阵,其形状与W0相同,即(d x k)。

具体来说,B矩阵可以被视为原始权重矩阵W0的列空间的一个压缩表示,而A矩阵则代表了这一压缩表示到原始权重空间的映射。通过这种方式,LoRA能够在保持模型性能的同时,显著减少在微调过程中需要更新的参数数量,从而降低了对计算资源的需求。

在实际应用中,这意味着我们可以在有限的计算资源下,对大型预训练模型进行有效的微调,以适应特定的业务需求或数据集。这种方法不仅提高了模型的适应性和灵活性,也为AI技术的广泛应用开辟了新的可能性。

正如我们在图中看到的,要创建一个具有更新权重和 9 个参数的 (3x3) 矩阵,我们只需要更新 6 个参数,即 B 和 A 各 3 个。对于这个特定矩阵来说,这似乎不是一个重要的数字,但在使用实际模型时,可训练参数的数量会急剧减少。

最近有一些发展,例如加州大学伯克利分校的 Soufiane Hayou、Nikhil Ghosh 和 Bin Yu 提出的 LoRA+。 LoRA+ 与 LoRA 相同,但 B 和 A 矩阵的学习率不同。

现在我们已经熟悉了 LoRA 及其内部工作原理,让我们从头开始构建一个 LoRA 并加深我们的理解。

#Let's create a LoRA class that will add two new matrices A and B to the
#original weights and return them such as B x A yields the same dinmensions as W

class LoRA(nn.Module):
  def __init__(self, features_in, features_out, rank, alpha, device = device):
    super().__init__()
    self.matrix_A = nn.Parameter(torch.zeros((features_out, rank)).to(device))
    self.matrix_B = nn.Parameter(torch.zeros((rank, features_in)).to(device))
    self.scale = alpha/rank
    nn.init.normal_(matrix_A, mean = 0, std = 1)  

  def forward(self, W):
    return W + torch.matmul(self.matrix_B,self.matrix_A).view(W.shape)*self.scale

现在我们已经创建了自己的 LoRA 类,它将在原始权重之上设置矩阵,我们需要创建一个函数,用 LoRA 类的输出替换层中的原始权重,即添加两个新矩阵在原始权重矩阵之上。

#This function takes the layer as the input and sets the features_in.features_out
#equal to the shape of the weight matrix. This will help the LoRA class to
#initialize the A and B Matrices

def layer_parametrization(layer, device, rank = 1, lora_alpha = 1):
  features_in, features_out = layer.weight.shape
  return LoRA(features_in, features_out, rank = rank, alpha = lora_alpha, device)

我们可以使用 PyTorch 库中的 Parametrize() 函数轻松地将此函数应用到模型层。

我们从头开始成功创建了 LoRA。唯一剩下的就是它的实际实施和实践经验以及 LoRA 的效率。为了方便起见,我们将训练一个模型来对 MNIST 数字进行分类,并针对特定数字对模型进行微调。如需完整代码,请访问我的 GitHub 存储库。

未经任何微调的原始分类器给出以下结果:

import torch.nn.utils.parametrize as parametrize

# Here we apply parametrization such that whenever the model wants to access 
# weights from the original linear layers of the model, 
# it returns the original weights plus LoRA matrices so that we can freeze the
# original weights and then train the LoRA matrices.

parametrize.register_parametrization(exp.linear1, 'weight', layer_parametrization(exp.linear1, device))
parametrize.register_parametrization(exp.linear2, 'weight', layer_parametrization(exp.linear2, device))
parametrize.register_parametrization(exp.linear3, 'weight', layer_parametrization(exp.linear3, device))

这里是 LoRA 引入的额外参数数量(仅供参考)。我们可以看到仅增加了 0.242% 的参数,使得 LoRA 的微调变得高效。

这里只附上了主要片段。

现在,让我们微调模型以提高其对数字 7 进行分类的准确性。我们将再次上传仅包含数字 7 的 MNIST 数据集并训练一个新模型,但这次使用 LoRA。

#freezing the non-LoRA matrices.

for name, param in exp.named_parameters():
    if 'mat' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

for layer in [exp.linear1, exp.linear2, exp.linear3]:
  layer.parametrizations["weight"][0].requires_grad = True

# Train the network with LoRA only on the digit 7 and only for 100 batches 
train(train_loader, exp, epochs=1, total_iterations_limits=100)

结果如下:

数字 7 的错误计数从 67 个减少到仅 5 个。这确实是一个很大的改进,而且附加参数很少。

参考:

如果觉得文章对你有用,欢迎大家点赞+关注。

  • 16
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Halo 咯咯

有你的支持我会更加努力。

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值