【DeepSeek】训练模型以及JAVA调用实际使用示范

训练DeepSeek模型

  • 机器配置:A100 显存40G,内存96G,硬盘160G(有实力硬盘可以加大点,默认注意清理缓存)
  • QLoRA 参数(适配 A100):
TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    num_train_epochs=3,
    fp16=True,  # 或 bf16=True,如果系统支持
    save_strategy="epoch",
    logging_steps=10,
    report_to="none"
)
  • 空间建议:
    • 模型文件(fp16):约 13GB(base),+ 微调权重 ≈ 1~2GB
    • 临时缓存(HF cache、训练 checkpoint):4080GB
    • 数据集(如果使用 JSONL 或 txt):通常 1~10GB
    • 建议清理 Hugging Face 缓存和 checkpoint 文件夹,以避免硬盘被撑爆。
  • 完整训练写作代码:
# 安装必要的依赖
!pip install -q accelerate transformers datasets peft bitsandbytes einops xformers onnx onnxruntime

# 导入所需库
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
from datasets import load_dataset, Dataset
from transformers import DataCollatorForLanguageModeling
from transformers import OnnxConfig, OnnxModel

# 设定 GPU 使用和显存限制(A100)
torch.cuda.empty_cache()

# ✅ 加载模型和 tokenizer
model_name = "deepseek-ai/deepseek-llm-7b-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    device_map="auto",
    trust_remote_code=True,
)

# ✅ 准备模型进行 QLoRA 微调
model = prepare_model_for_kbit_training(model)

# QLoRA 配置
lora_config = LoraConfig(
    r=8,                     # 调整低秩矩阵秩
    lora_alpha=32,           # 设置低秩矩阵的缩放因子
    target_modules=["q_proj", "v_proj"],  # 设置低秩适配应用的模块
    lora_dropout=0.05,       # 设置低秩矩阵的 dropout
    bias="none",             # 设置是否使用偏置
    task_type=TaskType.CAUSAL_LM   # 设置任务类型为因果语言建模(Causal LM)
)

# 将 QLoRA 配置应用到模型
model = get_peft_model(model, lora_config)

# ✅ 加载数据集
# 假设数据集在根目录下的 train.jsonl 文件
train_dataset = load_dataset("json", data_files={"train": "./train.jsonl"}, split="train")

# ✅ 数据处理:tokenization
def tokenize_function(example):
    return tokenizer(example["text"], truncation=True, max_length=512)

tokenized_dataset = train_dataset.map(tokenize_function, batched=True)

# ✅ 设置训练参数
training_args = TrainingArguments(
    output_dir="./qlora-output",               # 保存模型的输出目录
    per_device_train_batch_size=4,             # 设置批大小
    gradient_accumulation_steps=2,             # 累积梯度,减少显存占用
    learning_rate=2e-4,                        # 学习率
    num_train_epochs=3,                        # 训练的轮数
    fp16=True,                                 # 使用混合精度
    logging_steps=10,                          # 每 10 步记录一次日志
    save_strategy="epoch",                     # 每个 epoch 保存一次模型
    report_to="none"                           # 关闭日志报告到远程
)

# ✅ 数据收集器
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# ✅ 初始化训练器
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)

# ✅ 开始训练
trainer.train()

# ✅ 保存训练好的原始模型(PyTorch格式)
model.save_pretrained("qlora-deepseek-writer")
tokenizer.save_pretrained("qlora-deepseek-writer")

# ✅ 转换为 ONNX 格式并保存
dummy_input = torch.ones(1, 512, dtype=torch.long).to(model.device)  # 假设最大长度是 512,输入长度根据需要调整

# 将模型转为 ONNX 格式 这里的作用是用于java后续调用模型 不用可以去掉
onnx_output_path = "qlora-deepseek-writer.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_output_path,
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch_size", 1: "sequence_length"}, "logits": {0: "batch_size", 1: "sequence_length"}},
    do_constant_folding=True,  # 开启常量折叠,优化推理速度
    opset_version=13  # 设置 ONNX opset 版本(根据需要调整)
)

print(f"✅ 模型已成功保存为 ONNX 格式,路径:{onnx_output_path}")
  • Java 中加载 ONNX 模型以及使用示范
  • 步骤 1: 导入依赖
<dependency>
    <groupId>ai.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.14.1</version> <!-- 请确保使用的是最新版本 -->
</dependency>
  • 步骤 2: 加载 ONNX 模型并进行推理
import ai.onnxruntime.*;

import java.nio.LongBuffer;
import java.util.Collections;

public class OnnxModelInference {

    public static void main(String[] args) throws Exception {
        // 加载 ONNX 模型
        String modelPath = "path_to_your_model/qlora-deepseek-writer.onnx";
        try (OnnxRuntimeEnvironment env = OnnxRuntimeEnvironment.createEnvironment()) {
            OnnxModel model = env.loadModel(modelPath);

            // 输入数据: 例如输入一个文章的开头(可以是文本数据)
            String prompt = "请生成一篇关于人工智能的文章,并提供标题和关键词。";
            long[] inputIds = tokenizer.encode(prompt);  // 使用你训练时的 tokenizer 进行编码
            OnnxTensor inputTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(inputIds), new long[]{1, inputIds.length});

            // 推理: 调用模型进行推理,输出结果是生成的文章内容
            try (OnnxValue result = model.runInference(Collections.singletonMap("input_ids", inputTensor))) {
                long[] outputIds = (long[]) result.getValue(0); // 假设模型输出的是 token ID

                // 解码输出为文本
                String generatedText = tokenizer.decode(outputIds);  // 使用训练时的 tokenizer 解码
                System.out.println("生成的文章内容: " + generatedText);

                // 提取标题和关键词(这里可以通过简单的规则进行提取,或使用预训练模型)
                String title = extractTitle(generatedText);
                String keywords = extractKeywords(generatedText);
                System.out.println("标题: " + title);
                System.out.println("关键词: " + keywords);
            }
        }
    }

    // 提取标题的简单示例方法(你可以根据需求更复杂化)
    private static String extractTitle(String text) {
        String[] sentences = text.split("\\. "); // 按句子分割
        return sentences.length > 0 ? sentences[0] : "默认标题"; // 使用第一句话作为标题
    }

    // 提取关键词的简单示例方法
    private static String extractKeywords(String text) {
        String[] words = text.split(" ");
        if (words.length > 5) {
            return String.join(", ", words[0], words[1], words[2], words[3], words[4]);
        }
        return String.join(", ", words); // 简单返回前五个单词作为关键词
    }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

开源福利

你的鼓励将是我创作的最大努力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值