Training a DeepSeek model
- Machine configuration: one A100 (40 GB VRAM), 96 GB RAM, 160 GB disk (use a larger disk if you can; with the default size, remember to clean caches regularly)
- QLoRA parameters (sized for the A100; effective batch size = 4 × 2 = 8):
TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    num_train_epochs=3,
    fp16=True,
    save_strategy="epoch",
    logging_steps=10,
    report_to="none"
)
- Disk space estimates:
- Model files (fp16): about 13 GB for the base model, plus roughly 1~2 GB for the fine-tuned adapter weights
- Temporary caches (HF cache, training checkpoints): 40~80 GB
- Dataset (JSONL or txt): typically 1~10 GB
- Clean up the Hugging Face cache and checkpoint folders regularly so the disk does not fill up; a cleanup sketch follows this list.
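- A minimal cleanup sketch (assumptions: the default Hugging Face cache location ~/.cache/huggingface/hub and the ./qlora-output directory used by the training script below; adjust both to your environment):
import os
import shutil

# Report how much space the Hugging Face hub cache is using (assumes the default location).
hf_cache = os.path.expanduser("~/.cache/huggingface/hub")
if os.path.isdir(hf_cache):
    size_gb = sum(
        os.path.getsize(os.path.join(root, name))
        for root, _, names in os.walk(hf_cache)
        for name in names
    ) / 1e9
    print(f"HF cache: {size_gb:.1f} GB")

# After training finishes and the adapter has been saved, the intermediate
# checkpoint-* folders under the output directory can be removed safely.
output_dir = "./qlora-output"
if os.path.isdir(output_dir):
    for entry in os.listdir(output_dir):
        if entry.startswith("checkpoint-"):
            shutil.rmtree(os.path.join(output_dir, entry))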
- Full training script for the writing task:
!pip install -q accelerate transformers datasets peft bitsandbytes einops xformers onnx onnxruntime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
from datasets import load_dataset, Dataset
from transformers import DataCollatorForLanguageModeling
torch.cuda.empty_cache()
model_name = "deepseek-ai/deepseek-llm-7b-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    device_map="auto",
    trust_remote_code=True,
)
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)
model = get_peft_model(model, lora_config)
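# The dataset below is expected to be a JSONL file in which every line is a JSON object
# with a "text" field (the field name the tokenize step reads). Illustrative sample line,
# not taken from the original data:
# {"text": "Title: On AI-assisted writing\nBody: ..."}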
train_dataset = load_dataset("json", data_files={"train": "./train.jsonl"}, split="train")
def tokenize_function(example):
    return tokenizer(example["text"], truncation=True, max_length=512)
tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
training_args = TrainingArguments(
    output_dir="./qlora-output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    num_train_epochs=3,
    fp16=True,
    logging_steps=10,
    save_strategy="epoch",
    report_to="none"
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)
trainer.train()
model.save_pretrained("qlora-deepseek-writer")
tokenizer.save_pretrained("qlora-deepseek-writer")
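# Caveat (not from the original script): torch.onnx.export often cannot trace bitsandbytes
# 4-bit layers directly. A common workaround is to reload the base model in fp16, apply the
# saved adapter with PeftModel.from_pretrained(...), call merge_and_unload(), and export the
# merged fp16 model instead. The export call below assumes the loaded model is traceable as-is.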
dummy_input = torch.ones(1, 512, dtype=torch.long).to(model.device)
onnx_output_path = "qlora-deepseek-writer.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_output_path,
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "sequence_length"},
    },
    do_constant_folding=True,
    opset_version=13
)
print(f"✅ 模型已成功保存为 ONNX 格式,路径:{onnx_output_path}")
- Loading and using the ONNX model from Java
- Step 1: add the Maven dependency
<dependency>
    <groupId>ai.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.14.1</version> <!-- check for the latest available version -->
</dependency>
- Step 2: load the model and run inference
import ai.onnxruntime.*;
import java.nio.LongBuffer;
import java.util.Collections;
public class OnnxModelInference {

    public static void main(String[] args) throws Exception {
        String modelPath = "path_to_your_model/qlora-deepseek-writer.onnx";
        // The ONNX Runtime environment is a process-wide singleton; do not close it here.
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession session = env.createSession(modelPath, new OrtSession.SessionOptions())) {
            String prompt = "Please write an article about artificial intelligence and provide a title and keywords.";
            // ONNX Runtime ships no tokenizer; encode/decode below are placeholders for a Java-side
            // tokenizer matching the DeepSeek vocabulary (for example DJL's HuggingFaceTokenizer).
            long[] inputIds = encode(prompt);
            try (OnnxTensor inputTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(inputIds), new long[]{1, inputIds.length});
                 OrtSession.Result result = session.run(Collections.singletonMap("input_ids", inputTensor))) {
                // The exported graph returns logits of shape [batch, sequence, vocab] (assumed float here);
                // take the argmax over the vocabulary to get one token id per position. This is a single
                // forward pass, not an autoregressive decoding loop.
                float[][][] logits = (float[][][]) result.get(0).getValue();
                long[] outputIds = new long[logits[0].length];
                for (int pos = 0; pos < logits[0].length; pos++) {
                    int best = 0;
                    for (int v = 1; v < logits[0][pos].length; v++) {
                        if (logits[0][pos][v] > logits[0][pos][best]) {
                            best = v;
                        }
                    }
                    outputIds[pos] = best;
                }
                String generatedText = decode(outputIds);
                System.out.println("Generated text: " + generatedText);
                String title = extractTitle(generatedText);
                String keywords = extractKeywords(generatedText);
                System.out.println("Title: " + title);
                System.out.println("Keywords: " + keywords);
            }
        }
    }

    // Placeholder: turn the prompt into token ids with a tokenizer compatible with the model.
    private static long[] encode(String prompt) {
        throw new UnsupportedOperationException("Plug in a Java tokenizer for the DeepSeek vocabulary");
    }

    // Placeholder: turn token ids back into text with the same tokenizer.
    private static String decode(long[] ids) {
        throw new UnsupportedOperationException("Plug in a Java tokenizer for the DeepSeek vocabulary");
    }

    // Take the first sentence as the title, falling back to a default.
    private static String extractTitle(String text) {
        String[] sentences = text.split("\\. ");
        return sentences.length > 0 ? sentences[0] : "Default title";
    }

    // Take the first five whitespace-separated tokens as keywords.
    private static String extractKeywords(String text) {
        String[] words = text.split(" ");
        if (words.length > 5) {
            return String.join(", ", words[0], words[1], words[2], words[3], words[4]);
        }
        return String.join(", ", words);
    }
}