大模型训练与微调(6)——微调之 Prompt Tuning 详解
今天学了Prompt Tuning,记录一下!以下是 Prompt Tuning 的具体方法详解,涵盖其核心思想、实现步骤、优化策略及实践建议:
一、Prompt Tuning 的核心思想
Prompt Tuning 是一种参数高效微调方法(PEFT),通过在输入序列中插入可学习的连续提示向量(Prompt Embeddings),引导预训练模型适配下游任务。与传统的全参数微调不同,它仅更新提示向量,冻结模型主体参数,显著降低计算成本。
关键特点:
- 参数高效:仅需调整少量提示参数(通常占总参数的0.1%-1%);
- 任务适配灵活:不同任务可独立训练提示向量,共享同一预训练模型;
- 无需修改模型结构:仅通过输入序列扩展实现。
二、Prompt Tuning 的典型实现步骤
1. 提示向量的插入方式
-
输入序列构造:
在原始输入文本前拼接一组可学习的连续向量(即提示向量):[Prompt Embeddings] + [Input Text]
例如,对于分类任务,输入可能变为:
[P1][P2]...[Pn] + "这部电影的评分是:[MASK]"
-
位置扩展:
- 仅输入层:提示向量仅加在输入层的嵌入中(如Google提出的原始Prompt Tuning);
- 多层扩展(如Prefix-Tuning):在每一层Transformer的注意力模块前拼接提示向量,增强深层引导能力。
2. 参数冻结与优化
- 冻结模型参数:保持预训练模型的权重不变(仅训练提示向量);
- 优化目标:通过梯度下降更新提示向量,使其编码任务相关的知识,激活模型的预训练能力。
3. 提示向量的初始化
- 随机初始化:直接随机初始化(需更多训练数据);
- 语义初始化:用任务相关词汇的嵌入初始化(如分类任务用“好/坏”等词的嵌入),加速收敛。
三、Prompt Tuning 的具体变体
1. 基础 Prompt Tuning
- 适用场景:文本分类、生成任务(如T5模型);
- 实现方式:
仅在输入层添加提示向量,模型前向传播时将其与输入文本的嵌入拼接,通过注意力机制引导模型输出。
2. Prefix Tuning
- 改进点:
- 将提示向量插入每一层Transformer的Key-Value注意力模块前(而非仅输入层);
- 使用更复杂的编码方式(如MLP生成前缀向量,增强表示能力)。
- 优势:对深层特征的引导更充分,适合复杂生成任务(如对话、摘要)。
3. P-Tuning v2
- 特点:
- 结合Prompt Tuning和Prefix Tuning,支持每层独立的提示向量;
- 无需依赖外部模型(如LSTM)生成提示,直接优化连续向量。
四、具体实现细节(以PyTorch和HuggingFace为例)
1. 代码框架
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# 加载预训练模型和分词器
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# 定义可学习的提示向量(提示长度=5,隐藏层维度=768)
prompt_length = 5
hidden_size = model.config.hidden_size
prompt_embeddings = torch.nn.Parameter(torch.randn(prompt_length, hidden_size))
# 冻结模型参数
for param in model.parameters():
param.requires_grad = False
# 修改前向传播逻辑
def forward(input_ids, attention_mask):
# 获取原始输入的嵌入
input_embeds = model.bert.embeddings(input_ids) # BERT模型的嵌入层
# 拼接提示向量(扩展至batch维度)
batch_size = input_ids.size(0)
prompt_embeds = prompt_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
combined_embeds = torch.cat([prompt_embeds, input_embeds], dim=1)
# 调整attention_mask以包含提示部分
prompt_mask = torch.ones(batch_size, prompt_length).to(input_ids.device)
combined_mask = torch.cat([prompt_mask, attention_mask], dim=1)
# 通过模型主体
outputs = model.bert(
inputs_embeds=combined_embeds,
attention_mask=combined_mask
)
pooled_output = outputs.last_hidden_state[:, 0, :] # 取[CLS]向量
# 分类头
logits = model.classifier(pooled_output)
return logits
2. 训练流程
- 数据预处理:将任务输入转换为模型接受的格式(如分类任务添加[MASK]);
- 前向传播:拼接提示向量,扩展注意力掩码;
- 损失计算:根据任务类型选择损失函数(如交叉熵损失);
- 反向传播:仅更新提示向量和分类头参数;
- 保存与部署:仅需保存提示向量(极小的文件)。
五、Prompt Tuning的优缺点
优点 | 缺点 |
---|---|
1. 参数高效(仅训练0.1%-1%参数) | 1. 提示长度敏感(需实验确定最佳长度) |
2. 适配多任务(共享模型主体) | 2. 训练稳定性较低(依赖初始化与学习率) |
3. 无需修改模型结构 | 3. 生成任务可能冗余(需设计合适的提示模板) |
4. 适合少样本场景 | 4. 性能上限低于全参数微调 |
六、优化策略
1. 提示长度选择
- 分类任务:5-20个token;
- 生成任务:20-50个token(或逐层添加);
- 实验建议:通过网格搜索确定最优长度。
2. 初始化方法
- 随机初始化:适合数据量较大的场景;
- 语义初始化:用任务关键词的嵌入初始化(如情感分析用“好/坏”的嵌入);
- 任务相关初始化:用相似任务的预训练提示向量迁移。
3. 学习率设置
- 提示向量的学习率通常高于分类头(如提示向量用1e-4,分类头用1e-5);
- 使用分层学习率:深层提示向量学习率更低(如Prefix-Tuning)。
七、适用场景与案例
1. 典型场景
- 少样本学习(Few-shot):数据量不足时快速适配;
- 多任务学习:不同任务独立训练提示向量,共享模型主体;
- 资源受限环境:显存不足或需快速部署的场景。
2. 案例参考
- 文本分类:情感分析、主题分类;
- 生成任务:对话生成(如GPT-3 + Prompt Tuning);
- 序列标注:命名实体识别(通过[MASK]预测实体类型)。
八、与其他方法的对比
方法 | 参数更新量 | 是否需要修改模型 | 适用任务 |
---|---|---|---|
Prompt Tuning | 0.1%-1% | 否 | 分类、生成 |
LoRA | 0.1%-1% | 否(低秩分解) | 通用任务 |
Adapter | 1%-5% | 是(插入模块) | 多任务适配 |
P-Tuning v2 | 0.1%-3% | 否(多层提示) | 复杂任务、小模型 |
九、实践建议
- 少样本任务:优先尝试语义初始化提示向量;
- 生成任务:使用Prefix-Tuning(每层添加提示);
- 性能瓶颈:若效果不足,可结合LoRA(同时训练提示向量和低秩矩阵);
- 调参重点:关注提示长度、学习率、初始化策略。
通过灵活应用Prompt Tuning,可以在极低计算成本下实现大模型的高效适配,尤其适合资源受限或需要快速迭代的多任务场景。