目录
参数高效微调:LoRA、Adapter与Prompt Tuning实战
5. LoRA、Adapter与Prompt Tuning的对比
随着大规模预训练语言模型(如BERT、GPT等)的普及和深度学习计算资源的消耗,如何高效微调这些巨型模型成为了研究的热点。传统的微调方式需要调整大量参数,导致显存消耗和计算开销较大,尤其在多任务学习或在资源受限的环境中,这种方式可能变得不切实际。为了解决这一问题,研究者们提出了几种参数高效的微调方法,如LoRA(Low-Rank Adaptation)、Adapter与Prompt Tuning,这些方法可以有效减少微调时的显存占用,甚至达到减少90%显存的效果。
本文将详细介绍这三种技术,并通过代码示例展示如何在实际应用中实现低成本的微调。
1. 为什么需要参数高效的微调方法?
在传统的微调方法中,通常会直接对预训练模型的所有参数进行更新,这样会导致以下问题:
- 显存消耗大:大规模模型的参数量非常庞大,更新所有参数会消耗大量显存。
- 训练效率低:每次训练都需要从头到尾处理整个模型的参数,导致训练时间过长。
- 计算资源浪费:对于单一任务或少量任务而言,微调整个模型可能是不必要的,浪费了计算资源。
因此,开发参数高效的微调方法,能够减少微调时的显存占用并提高训练效率。
2. LoRA(Low-Rank Adaptation)
LoRA(低秩适应)是一种高效的微调方法,其核心思想是通过在原始模型的层之间引入低秩矩阵,来减少微调时需要更新的参数量。LoRA通过将预训练模型中部分权重矩阵分解为低秩矩阵,使得只需对低秩矩阵进行更新,而不是整个模型。
2.1 原理
- 传统的微调方法会更新模型的所有权重参数,而LoRA将每一层的权重矩阵
分解为两个低秩矩阵
和
,即
,其中
和
是低秩矩阵,低秩矩阵的维度远小于原始矩阵。
- 在训练时,仅更新
和
,这大大减少了需要更新的参数数量。
公式:假设原始模型的权重矩阵为,LoRA通过低秩矩阵
和
来近似
。其中,
表示低秩矩阵的秩,通常远小于
的维度
。
2.2 代码示例(LoRA微调)
以下是使用LoRA对BERT进行微调的简化示例代码:
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import torch.nn as nn
class LoRA_BERT(nn.Module):
def __init__(self, model_name='bert-base-uncased', rank=4):
super(LoRA_BERT, self).__init__()
self.bert = BertForSequenceClassification.from_pretrained(model_name)
self.rank = rank
# 插入低秩适应层
self.lora_layers = nn.ModuleList([
nn.Linear(self.bert.config.hidden_size, rank, bias=False),
nn.Linear(rank, self.bert.config.hidden_size, bias=False)
])
def forward(self, input_ids, attention_mask=None, labels=None):
# 获取BERT输出
outputs = self.bert.bert(input_ids, attention_mask=attention_mask)
hidden_states = outputs[0]
# 低秩适应
lora_output = self.lora_layers[0](hidden_states)
lora_output = self.lora_layers[1](lora_output)
# 将LoRA层的输出与BERT输出结合
final_output = hidden_states + lora_output
# 分类层
logits = self.bert.classifier(final_output)
return logits
# 初始化LoRA模型
model = LoRA_BERT()
# 示例输入
input_ids = torch.tensor([[101, 2054, 2003, 1996, 2196, 102]]) # 示例输入文本(转化为ID)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
# 获取预测
outputs = model(input_ids, attention_mask)
print(outputs)
在这个示例中,LoRA_BERT
模型在原始BERT模型的基础上添加了两个低秩矩阵(和
),通过这些矩阵对BERT的输出进行调整,而不会更新整个模型的权重。这样,LoRA显著降低了显存消耗。
3. Adapter
Adapter方法是一种微调技术,它通过在预训练模型的各层之间插入轻量级的适配器(Adapter)模块来减少需要更新的参数量。每个适配器模块通常包含两个全连接层,通过这些模块可以有效地调整模型的行为。
3.1 原理
- Adapter方法通过在每个Transformer层的输入和输出之间插入一个小型的神经网络(通常是一个简单的全连接层),仅更新这些适配器参数,而不更新原始模型的权重。
- 适配器的大小通常很小,因此大大减少了需要更新的参数量。
公式:假设一个Transformer层的输出为hh,通过适配器模块的操作可以表示为:
其中,是适配器模块的变换函数,
是适配器的参数。
3.2 代码示例(Adapter微调)
以下是一个基于HuggingFace库实现的Adapter微调代码示例:
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from transformers import AdapterConfig, AdapterType
# 加载预训练BERT
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
# 添加适配器
adapter_name = "my_adapter"
config = AdapterConfig.load("pfeiffer") # 使用Pfeiffer适配器策略
model.add_adapter(adapter_name, config=config)
# 启用适配器训练
model.train_adapter(adapter_name)
# 示例输入
input_ids = torch.tensor([[101, 2054, 2003, 1996, 2196, 102]]) # 示例输入文本(转化为ID)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
# 获取预测
outputs = model(input_ids, attention_mask)
print(outputs.logits)
在这个示例中,使用了一个适配器来微调BERT模型,适配器的更新仅影响添加的模块,而不需要重新训练整个模型。
4. Prompt Tuning
Prompt Tuning是一种新兴的微调技术,特别适用于像GPT这类生成式模型。其基本思想是在输入文本中加入一些特定的提示(Prompt),并通过学习这些提示来调整模型的行为。与传统微调方法不同,Prompt Tuning只需调整输入的提示词而不涉及模型内部权重的更新。
4.1 原理
- 在Prompt Tuning中,模型的权重保持不变,仅通过学习如何构造有效的提示(Prompt)来调整模型的输出。
- 通过优化提示的词向量,可以使模型在特定任务中表现更好。
公式:在Prompt Tuning中,我们将输入文本转化为带有优化过的提示的形式
,其中
是优化得到的提示。
4.2 代码示例(Prompt Tuning)
以下是一个基于HuggingFace的GPT-2模型的Prompt Tuning代码示例:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
# 加载GPT2模型和Tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# 示例输入和Prompt
prompt = "Translate the following English sentence to French: 'Hello, how are you?'"
inputs = tokenizer(prompt, return_tensors="pt")
# 获取GPT2的输出
outputs = model.generate(inputs["input_ids"], max_length=50)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
在这个示例中,通过Prompt Tuning,只需要通过设计有效的提示词来引导GPT-2模型进行特定任务的生成,而不需要对模型的权重进行微调。
5. LoRA、Adapter与Prompt Tuning的对比
为了帮助大家更清晰地理解这三种参数高效微调技术,我们通过以下表格对它们进行对比:
特性 | LoRA | Adapter | Prompt Tuning |
---|---|---|---|
原理 | 引入低秩矩阵进行微调 | 在每个层之间插入适配器模块 | 通过优化提示词(Prompt)调整模型行为 |
参数更新 | 只更新低秩矩阵(AA和BB) | 只更新适配器模块的参数 | 只更新提示词(Prompt)的参数 |
显存占用 | 极大减少显存占用 | 显存占用较少 | 显存占用最少 |
计算开销 | 计算开销相对较小 | 计算开销较小 | 计算开销最小 |
适用场景 | 多任务学习,文本分类 | 特定任务微调,尤其是低资源环境 | 生成任务、对话系统 |
实现复杂度 | 较高,涉及矩阵分解 | 中等,需要添加适配器 | 低,只需要优化提示词 |
6. 总结
在深度学习模型的微调过程中,LoRA、Adapter和Prompt Tuning是三种有效的参数高效微调方法。这些方法能够显著降低显存消耗,提高训练效率,同时在多任务学习和低资源环境中表现出色。选择哪种方法取决于任务的类型、资源的限制以及模型的特点。希望本文能够帮助大家理解并应用这些技术来优化模型训练。如果有任何问题,欢迎留言讨论!
推荐阅读:
手把手搭建你的第一个大模型:基于HuggingFace的模型微调-CSDN博客