Below is a LoRA fine-tuning code example for DeepSeek-7B, clearly structured and runnable as-is:
Environment setup

```bash
pip install transformers peft datasets accelerate bitsandbytes
```
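Before training, a quick sanity check that PyTorch can see your GPU helps catch environment problems early (this sketch assumes a CUDA device):

```python
import torch

# Fail fast if no GPU is visible to PyTorch.
assert torch.cuda.is_available(), "No CUDA GPU detected"
props = torch.cuda.get_device_properties(0)
print(props.name, props.total_memory // 2**30, "GiB")
```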
Implementation

1. Load the model and data

```python
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-llm-7b-chat",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/deepseek-llm-7b-chat", trust_remote_code=True
)
# The tokenizer has no pad token by default; reuse EOS so batching works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Example data (JSON format)
data = [{
    "instruction": "You are an assistant",
    "input": "Explain quantum computing",
    "output": "Quantum computing exploits the superposition of qubits to compute in parallel...",
}]
```
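In practice the training set usually lives in a file rather than an inline list, and `datasets` can load it directly. A sketch, assuming a hypothetical `train.jsonl` with one JSON object per line in the same instruction/input/output shape as above:

```python
from datasets import load_dataset

# "train.jsonl" is a placeholder path; each line is a JSON object with
# instruction/input/output fields matching the inline example.
raw_dataset = load_dataset("json", data_files="train.jsonl", split="train")
```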
2. Data preprocessing

```python
def format_data(example):
    # Combine instruction, input, and output into a single chat-style string.
    text = (
        f"User: {example['instruction']} {example['input']}\n\n"
        f"Assistant: {example['output']}"
    )
    return tokenizer(text, truncation=True, max_length=512)

# Build the dataset from the instruction records, then tokenize each one.
dataset = Dataset.from_list(data)
dataset = dataset.map(format_data, remove_columns=dataset.column_names)
```
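The formatting above computes loss over the whole sequence, prompt included. A common refinement is to mask the prompt tokens so loss is computed only on the assistant response. A minimal sketch; `format_data_masked` is a hypothetical variant of the function above, and the prompt/response boundary can shift by a token when the tokenizer inserts special tokens, so verify on a sample first:

```python
def format_data_masked(example):
    prompt = f"User: {example['instruction']} {example['input']}\n\nAssistant: "
    tokens = tokenizer(prompt + example["output"], truncation=True, max_length=512)
    prompt_len = len(tokenizer(prompt)["input_ids"])
    labels = tokens["input_ids"].copy()
    for i in range(min(prompt_len, len(labels))):
        labels[i] = -100  # -100 is ignored by the cross-entropy loss
    tokens["labels"] = labels
    return tokens
```

If you precompute labels this way, swap the collator in step 5 for `DataCollatorForSeq2Seq(tokenizer)`, because `DataCollatorForLanguageModeling` regenerates labels from `input_ids` and would overwrite the mask.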
3. Configure LoRA

```python
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # attention layers to adapt
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # example output: ~0.3% of parameters trainable
```
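Valid `target_modules` names vary by architecture. A quick way to discover them is to list the model's unique linear-layer names; run this on the base model before calling `get_peft_model`:

```python
import torch.nn as nn

# Collect the leaf names of all nn.Linear modules in the base model.
linear_names = {
    name.split(".")[-1]
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear)
}
print(sorted(linear_names))  # expect names like q_proj, k_proj, v_proj, o_proj, ...
```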
4. Training arguments

```python
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    learning_rate=2e-5,
    bf16=True,  # match the bfloat16 weights loaded above (use fp16 only on GPUs without bf16)
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
)
```
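With these settings the effective batch size is per_device_train_batch_size × gradient_accumulation_steps = 4 × 8 = 32 per GPU: gradients are accumulated over 8 small forward/backward passes before each optimizer step, so memory use stays low while the optimizer still sees a batch of 32.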
5. Start training

```python
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    # mlm=False makes the collator copy input_ids into labels for causal-LM loss.
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```
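After training, you typically persist just the LoRA adapter (a few dozen MB) rather than the full model; the paths below are examples. `merge_and_unload()` folds the adapter into the base weights if you want a standalone checkpoint:

```python
# Save only the adapter weights and the tokenizer (paths are examples).
model.save_pretrained("./lora_adapter")
tokenizer.save_pretrained("./lora_adapter")

# Optional: merge the adapter into the base weights for standalone deployment.
merged_model = model.merge_and_unload()
merged_model.save_pretrained("./merged_model")
```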
Inference test

```python
model.eval()
input_text = "User: Explain quantum computing\n\nAssistant:"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
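To run inference in a fresh process later, reload the base model and attach the saved adapter; this sketch assumes the `./lora_adapter` path used in the save step above:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-llm-7b-chat",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Attach the trained LoRA weights on top of the frozen base model.
model = PeftModel.from_pretrained(base, "./lora_adapter")
tokenizer = AutoTokenizer.from_pretrained("./lora_adapter", trust_remote_code=True)
```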
Key parameters

| Parameter | Purpose | Recommended value |
|---|---|---|
| `target_modules` | Attention layers to inject LoRA into | `["q_proj", "v_proj"]` |
| `r` | Rank controlling the adapter's capacity | 8-64 (8 is usually enough) |
| `per_device_train_batch_size` | Tune to GPU memory (4 fits in 24 GB) | 2-8 |
Memory optimization tips

- 4-bit quantization (brings the requirement down to about 12 GB of VRAM); on recent transformers versions prefer `BitsAndBytesConfig` over the bare flag, as in the sketch after this list:

```python
model = AutoModelForCausalLM.from_pretrained(..., load_in_4bit=True)
```

- Gradient checkpointing (recomputes activations in the backward pass to save memory):

```python
model.gradient_checkpointing_enable()
```
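A minimal QLoRA-style sketch of the 4-bit path, reusing `lora_config` from step 3; the quantization settings shown are common defaults, not values from the original recipe:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bf16 on top of 4-bit weights
)
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-llm-7b-chat",
    quantization_config=bnb_config,
    device_map="auto",
)
# Prepares the quantized model for training (casts norms, enables input grads).
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
```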