一、从 MySQL 提取训练数据
1. 数据查询与导出
import pandas as pd
import pymysql
# 连接 MySQL 数据库
conn = pymysql.connect(
host="localhost",
user="your_username",
password="your_password",
database="your_database"
)
# 执行 SQL 查询(示例:提取客服对话记录)
query = """
SELECT
user_query AS instruction,
context AS input,
agent_response AS output
FROM customer_service_logs
WHERE created_at > '2024-01-01'
"""
df = pd.read_sql(query, conn)
conn.close()
# 保存为中间文件(可选)
df.to_csv("training_data.csv", index=False)
2. 数据格式化(转换为 Alpaca 格式)
def format_alpaca(row):
text = f"### Instruction:\n{row['instruction']}\n### Input:\n{row['input']}\n### Response:\n{row['output']}"
return {"text": text}
dataset = Dataset.from_pandas(df)
dataset = dataset.map(format_alpaca)
二、环境配置与模型加载
1. 安装依赖
pip install unsloth pymysql pandas
2. 加载 DeepSeek 模型(4bit 量化)
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="deepseek-ai/deepseek-llm-7b-chat",
max_seq_length=2048,
load_in_4bit=True, # 16GB 显存即可运行
device_map="auto",
)
三、注入 LoRA 适配器
model = FastLanguageModel.get_peft_model(
model,
r=16, # 推荐值:8(低显存)至 64(高精度)
lora_alpha=32,
target_modules=["q_proj", "v_proj", "o_proj"],
use_gradient_checkpointing=True, # 长序列优化
)
四、训练参数配置
from transformers import TrainingArguments
trainer_args = TrainingArguments(
output_dir="./output",
per_device_train_batch_size=2, # RTX 4090 建议设为4
gradient_accumulation_steps=8, # 等效增大 batch_size
num_train_epochs=3,
learning_rate=2e-4, # 推荐范围:1e-5 至 3e-4
fp16=torch.cuda.is_bf16_supported(),
logging_steps=10,
)
五、启动训练
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=2048,
tokenizer=tokenizer,
)
trainer.train() # 训练时长示例:1000条数据在 RTX 3090 上约需 45 分钟
六、模型验证与部署
1. 推理测试
def generate_response(prompt):
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=200)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
# 测试业务场景问题
print(generate_response("用户:订单号 20240515001 的状态是什么?"))
# 预期输出:"订单 20240515001 当前状态为已发货,预计 3 个工作日内送达。"
2. 保存与部署
# 保存 LoRA 适配器
model.save_pretrained("./lora_weights")
# 转换为 Hugging Face 格式(可选)
model.push_to_hub("your-username/deepseek-mysql-finetuned")
七、关键优化技巧
1. MySQL 数据实时同步
- 通过 Debezium 监听数据库变更日志(CDC),实现增量数据自动导入训练流程:
from kafka import KafkaConsumer consumer = KafkaConsumer("mysql_customer_service", bootstrap_servers='localhost:9092') for msg in consumer: new_data = json.loads(msg.value) dataset.add_item(format_alpaca(new_data)) # 动态更新数据集
2. 性能调优参数
参数 | 推荐值 | 作用 |
---|---|---|
per_device_train_batch_size | 2-4 | 平衡显存占用与训练速度 |
r | 8-32 | 控制模型适配能力与显存消耗 |
gradient_accumulation_steps | 8-16 | 等效增大 batch_size 以降低显存需求 |
3. 安全增强
- 数据库连接使用 SSH 隧道 或 VPC 私有网络
- 敏感信息通过环境变量注入:
import os db_password = os.environ["MYSQL_PASSWORD"]
通过以上步骤,可高效利用 MySQL 业务数据完成 DeepSeek 模型的定制化训练,适用于订单查询、客服问答等场景。实际应用中需根据数据规模调整训练轮次和 LoRA 参数。