get_peft_model
是 PEFT (Parameter-Efficient Fine-Tuning) 框架中的一个核心函数,通常用于加载或创建一个可以高效微调的模型,尤其适合在低资源场景或小型数据集上进行模型微调。PEFT 框架支持的技术包括 LoRA (Low-Rank Adaptation)、Prefix Tuning、Prompt Tuning 等。
下面是如何使用 get_peft_model
的示例:
安装所需库
bash
复制代码
pip install peft transformers
示例代码
python
复制代码
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
# 加载基础模型和分词器
model_name = "gpt2" # 可替换为其他模型,例如 "bert-base-uncased"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 配置 PEFT 微调(以 LoRA 为例)
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, # 任务类型(如文本生成、分类等)
inference_mode=False, # 是否仅用于推理
r=8, # LoRA 矩阵的秩
lora_alpha=32, # LoRA 缩放因子
lora_dropout=0.1 # Dropout 概率
)
# 应用 PEFT 配置到模型
peft_model = get_peft_model(model, peft_config)
# 检查模型结构
print(peft_model)
# 示例微调
input_text = "Hello, how are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
labels = input_ids.clone()
outputs = peft_model(input_ids=input_ids, labels=labels)
loss = outputs.loss
print(f"Loss: {loss.item()}")
# 保存微调后的模型
peft_model.save_pretrained("peft_lora_model")
参数说明
task_type
: 指定任务类型,例如TaskType.CAUSAL_LM
(因果语言建模)、TaskType.SEQ_CLS
(序列分类)等。r
: LoRA 的秩,通常设为 4 或 8。lora_alpha
: 调整 LoRA 参数对模型输出的影响。lora_dropout
: 避免过拟合的 Dropout 概率。
PEFT 优势
- 参数高效:只更新少量新增参数,减少计算和存储开销。
- 低资源需求:适合小型数据集和低计算资源环境。
- 灵活性强:支持多种任务和模型。
如果有更具体的需求(例如其他 PEFT 方法或特定模型支持),可以进一步调整配置。