领域专用AI训练工具(离线版)

现在的AI多数是通用的,未来会有更多专用的AI

把专业领域的书籍资料扔给AI学习,让它成为该领域的专家

这是我写这段代码的初衷。

这段代码实现了一个基于Gradio的领域专用AI训练工具的Web界面,主要功能是让用户能够上传数据集并训练文本分类模型。以下是代码的主要功能和结构分析:

主要功能

  1. 模型训练功能

    • 支持上传CSV/Excel格式的数据集

    • 可选择不同的预训练模型(BERT、DistilBERT、RoBERTa)

    • 可调整学习率、批次大小和训练轮数等超参数

    • 自动跟踪和显示训练过程中的损失和准确率指标

  2. 可视化功能

    • 实时绘制训练损失和验证准确率曲线

    • 双坐标轴显示不同量纲的指标

  3. 模型测试功能

    • 提供界面输入测试文本

    • 显示预测结果(当前是模拟功能)

代码结构

  1. 导入库

    • 使用了Hugging Face的transformers库进行模型训练

    • 使用Gradio构建Web界面

    • 使用pandas和datasets处理数据

  2. 训练函数 train_model()

    • 数据预处理和标签编码

    • 加载预训练模型和分词器

    • 设置训练参数和评估指标

    • 训练过程记录指标到全局变量

  3. 可视化函数 plot_metrics()

    • 使用matplotlib绘制训练曲线

    • 支持双坐标轴显示不同指标

  4. Gradio界面构建

    • 左侧面板:数据上传、参数配置

    • 右侧面板:训练状态、指标图表、模型测试

使用说明

  1. 上传包含文本和标签列的CSV文件

  2. 指定文本列和标签列的名称

  3. 选择预训练模型和调整参数

  4. 点击"开始训练"按钮

  5. 训练完成后可以在下方测试模型

注意事项

  1. 当前测试功能是模拟的,需要替换为真实的模型预测代码

  2. 训练回调函数可能需要调整以正确记录指标

  3. 需要安装所有依赖库才能运行

  4. 界面默认运行在7860端口

这个工具适合初学者体验NLP模型训练的基本流程,可以通过修改代码来扩展更多功能。

在自然语言处理(NLP)领域,BERT、DistilBERT 和 RoBERTa 都是基于 Transformer 架构的预训练模型,但它们在模型结构、训练方式和性能上有所不同。以下是它们的核心区别:


1. BERT(BERT-base-uncased)

  • 特点

    • 双向上下文理解:通过掩码语言建模(MLM)和下一句预测(NSP)任务预训练,能同时捕捉左右上下文信息。

    • 模型规模:Base 版本有 12 层 Transformer,约 1.1 亿参数。

    • 训练数据:BooksCorpus 和英文维基百科。

  • 优势

    • 通用性强,适合大多数 NLP 任务(如文本分类、问答、NER)。

    • 开源生态完善,文档和教程丰富。

  • 缺点

    • 计算资源消耗较大(训练和推理速度较慢)。

    • NSP 任务在后继研究中被证明效果有限。


2. DistilBERT(DistilBERT-base-uncased)

  • 特点

    • 轻量化版 BERT:通过知识蒸馏(Knowledge Distillation)技术,用 BERT 作为教师模型训练的小型模型。

    • 模型规模:6 层 Transformer,参数量减少 40%(约 6600 万),但保留 97% 的 BERT 性能。

    • 训练任务:仅保留 MLM,移除了 NSP。

  • 优势

    • 推理速度快 60%,内存占用更低,适合资源受限场景(如移动端、边缘设备)。

    • 在大多数任务上性能接近 BERT。

  • 缺点

    • 极复杂任务(如需要深层语义理解)上可能略逊于原始 BERT。

    • 蒸馏过程可能丢失部分细粒度语义信息。


3. RoBERTa(RoBERTa-base)

  • 特点

    • BERT 的优化版:由 Facebook 提出,改进了 BERT 的训练策略。

    • 关键改进

      • 移除 NSP 任务,仅用 MLM。

      • 更大的批次大小(8K vs BERT 的 256)和更多数据(160GB vs BERT 的 16GB)。

      • 动态掩码(Dynamic Masking)代替静态掩码。

    • 模型规模:与 BERT-base 相同(12 层),但训练更充分。

  • 优势

    • 性能更强,在 GLUE、SQuAD 等基准上超越 BERT。

    • 对长文本和复杂任务表现更好。

  • 缺点

    • 训练成本极高(需要大量计算资源)。

    • 推理速度与 BERT 相近,仍比 DistilBERT 慢。


对比总结

特性BERTDistilBERTRoBERTa
参数量~110M~66M~110M
训练数据量16GB同 BERT160GB
训练任务MLM + NSPMLMMLM(动态掩码)
推理速度中等中等
典型应用场景通用 NLP资源受限环境高性能需求任务

选择建议

  1. 优先 RoBERTa:如果追求最高准确率且资源充足。

  2. 优先 DistilBERT:如果需要快速部署或资源有限(如 API 服务)。

  3. 选择 BERT:如果需平衡性能和兼容性,或依赖现有 BERT 生态工具。

在你的训练工具中,用户可以根据任务需求(速度 vs 精度)和硬件条件灵活选择模型。

以下是完整代码,由于是离线版,需要在根目录创建model文件夹然后将模型下载放入,下载链接:模型http://通过网盘分享的文件:model.zip 链接: https://pan.baidu.com/s/1YxzD9T4jd545FNzPwte_Sg?pwd=jfpj 提取码: jfpj --来自百度网盘超级会员v8的分享

import gradio as gr
import pandas as pd
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    pipeline
)
import datasets
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import os
import json
from pathlib import Path
from io import StringIO

# 全局变量存储训练状态和模型
training_history = {"loss": [], "accuracy": []}
current_model = None
current_tokenizer = None
label_map = {}
model_dir = "./model"  # 模型存放目录

class LoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            if "loss" in logs and "epoch" in logs:
                training_history["loss"].append((logs["epoch"], logs["loss"]))
            if "eval_accuracy" in logs and "epoch" in logs:
                training_history["accuracy"].append((logs["epoch"], logs["eval_accuracy"]))

def get_available_models():
    """获取本地可用的模型列表"""
    models = []
    for model_name in ["bert-base-uncased", "distilbert-base-uncased", "roberta-base", "bert-base-chinese"]:
        model_path = os.path.join(model_dir, model_name)
        if os.path.exists(model_path):
            models.append(model_name)
    return models

def load_data(data_path, text_column, label_column):
    try:
        # 读取数据文件
        if data_path.endswith('.csv'):
            df = pd.read_csv(data_path)
        elif data_path.endswith('.xlsx'):
            df = pd.read_excel(data_path)
        else:
            return None, "不支持的文件格式,请上传CSV或Excel文件"
        
        # 检查列名是否存在
        if text_column not in df.columns or label_column not in df.columns:
            return None, f"数据集中找不到指定的列: '{text_column}' 或 '{label_column}'"
        
        # 检查数据是否为空
        if df.empty:
            return None, "数据集为空"
        
        # 创建标签映射
        global label_map
        label_list = sorted(df[label_column].unique().tolist())
        label_map = {label: i for i, label in enumerate(label_list)}
        
        # 转换数据集格式
        dataset = datasets.Dataset.from_pandas(df)
        dataset = dataset.map(lambda x: {"label": label_map[x[label_column]]})
        
        return dataset, None
    except Exception as e:
        return None, f"数据加载失败: {str(e)}"

def generate_dataset_from_text(input_text, output_format="csv"):
    """从文本生成数据集"""
    try:
        # 处理输入文本
        data = []
        for line in input_text.strip().split("\n"):
            line = line.strip()
            if line and "诊断为" in line:
                # 分割文本和标签
                parts = line.split(",诊断为")
                if len(parts) == 2:
                    text = parts[0].strip()
                    label = parts[1].replace("。", "").strip()
                    data.append({"text": text, "label": label})
        
        if not data:
            return "未找到有效数据,请确保每行包含'诊断为'分隔符", None
        
        # 创建DataFrame
        df = pd.DataFrame(data)
        
        # 创建输出目录
        output_dir = "./generated_dataset"
        os.makedirs(output_dir, exist_ok=True)
        
        # 保存标签映射
        global label_map
        label_list = sorted(df["label"].unique().tolist())
        label_map = {label: i for i, label in enumerate(label_list)}
        
        with open(os.path.join(output_dir, "label_map.json"), "w") as f:
            json.dump(label_map, f)
        
        # 根据格式保存
        if output_format == "csv":
            output_path = os.path.join(output_dir, "generated_dataset.csv")
            df.to_csv(output_path, index=False, encoding="utf-8")
        elif output_format == "json":
            output_path = os.path.join(output_dir, "generated_dataset.json")
            df.to_json(output_path, orient="records", force_ascii=False)
        elif output_format == "huggingface":
            output_path = output_dir
            dataset = datasets.Dataset.from_pandas(df)
            dataset.save_to_disk(output_path)
        else:
            return "不支持的输出格式", None
        
        # 返回数据集预览
        preview_df = df.head(5)
        return f"数据集已生成并保存为{output_format.upper()}格式到: {output_path}", preview_df
    
    except Exception as e:
        return f"生成数据集失败: {str(e)}", None

def convert_to_dataset(data_path, text_column, label_column, output_format="csv"):
    """将数据转换为指定格式的数据集"""
    try:
        dataset, error_msg = load_data(data_path, text_column, label_column)
        if error_msg:
            return error_msg, None
        
        output_dir = "./converted_dataset"
        os.makedirs(output_dir, exist_ok=True)
        
        # 保存标签映射
        with open(os.path.join(output_dir, "label_map.json"), "w") as f:
            json.dump(label_map, f)
        
        # 转换为指定格式
        if output_format == "csv":
            output_path = os.path.join(output_dir, "dataset.csv")
            dataset.to_csv(output_path)
        elif output_format == "json":
            output_path = os.path.join(output_dir, "dataset.json")
            dataset.to_json(output_path)
        elif output_format == "huggingface":
            output_path = output_dir
            dataset.save_to_disk(output_path)
        else:
            return "不支持的输出格式", None
        
        # 返回数据集预览
        preview_df = dataset.to_pandas().head(5)
        return f"数据集已成功转换为{output_format.upper()}格式,保存到: {output_path}", preview_df
    except Exception as e:
        return f"数据集转换失败: {str(e)}", None

def preview_data(data_path):
    try:
        if data_path.endswith('.csv'):
            df = pd.read_csv(data_path, nrows=5)
        elif data_path.endswith('.xlsx'):
            df = pd.read_excel(data_path, nrows=5)
        else:
            return "不支持的文件格式", None
        return None, df
    except Exception as e:
        return f"数据预览失败: {str(e)}", None

def train_model(
    data_path,
    text_column,
    label_column,
    model_name,
    learning_rate,
    batch_size,
    epochs,
    output_dir
):
    global current_model, current_tokenizer, training_history, label_map
    
    try:
        # 加载数据
        dataset, error_msg = load_data(data_path, text_column, label_column)
        if error_msg:
            return error_msg
        
        # 检查模型是否存在
        model_path = os.path.join(model_dir, model_name)
        if not os.path.exists(model_path):
            return f"找不到本地模型: {model_name},请确保模型已下载到{model_dir}目录"
        
        # 加载模型和分词器
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSequenceClassification.from_pretrained(
                 model_path, 
                 num_labels=len(label_map),
                 ignore_mismatched_sizes=True  # 忽略尺寸不匹配
            )
        except Exception as e:
            return f"模型加载失败: {str(e)}"
        
        # 数据预处理
        def tokenize_function(examples):
            return tokenizer(examples[text_column], truncation=True, padding="max_length", max_length=128)
        
        tokenized_dataset = dataset.map(tokenize_function, batched=True)
        tokenized_dataset = tokenized_dataset.train_test_split(test_size=0.2, seed=42)
        
        # 创建输出目录
        os.makedirs(output_dir, exist_ok=True)
        
        # 保存标签映射
        with open(os.path.join(output_dir, "label_map.json"), "w") as f:
            json.dump(label_map, f)
        
        # 训练配置
        training_args = TrainingArguments(
            output_dir=output_dir,
            evaluation_strategy="epoch",
            learning_rate=float(learning_rate),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size),
            num_train_epochs=int(epochs),
            weight_decay=0.01,
            save_strategy="epoch",
            load_best_model_at_end=True,
            logging_steps=10,
            report_to="none"
        )
        
        # 定义评估指标
        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            predictions = np.argmax(logits, axis=-1)
            return {"accuracy": accuracy_score(labels, predictions)}
        
        # 重置训练历史
        training_history = {"loss": [], "accuracy": []}
        
        # 训练
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            eval_dataset=tokenized_dataset["test"],
            compute_metrics=compute_metrics,
        )
        
        trainer.add_callback(LoggingCallback())
        trainer.train()
        
        # 保存模型
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        
        # 更新当前模型
        current_model = model
        current_tokenizer = tokenizer
        
        return f"训练完成!模型已保存到 {output_dir}"
    except Exception as e:
        return f"训练过程中出错: {str(e)}"

def plot_metrics():
    if not training_history["loss"]:
        return None
    
    fig, ax = plt.subplots(figsize=(10, 5))
    
    # 准备数据
    epochs = [x[0] for x in training_history["loss"]]
    loss_values = [x[1] for x in training_history["loss"]]
    
    ax.plot(epochs, loss_values, label="训练损失", marker='o')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.grid(True)
    
    if training_history["accuracy"]:
        acc_epochs = [x[0] for x in training_history["accuracy"]]
        acc_values = [x[1] for x in training_history["accuracy"]]
        ax2 = ax.twinx()
        ax2.plot(acc_epochs, acc_values, color="red", label="验证准确率", marker='x')
        ax2.set_ylabel("Accuracy")
        ax2.set_ylim(0, 1)
        
        # 合并图例
        lines, labels = ax.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax.legend(lines + lines2, labels + labels2, loc="upper center")
    else:
        ax.legend()
    
    plt.title("训练指标")
    plt.tight_layout()
    return fig

def predict(text, model_path=None):
    global current_model, current_tokenizer, label_map
    
    try:
        # 如果没有传入模型路径,使用当前训练的模型
        if model_path:
            if not os.path.exists(model_path):
                return {"error": "模型路径不存在"}
            
            # 加载模型和分词器
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSequenceClassification.from_pretrained(model_path)
            
            # 加载标签映射
            label_map_path = os.path.join(model_path, "label_map.json")
            if os.path.exists(label_map_path):
                with open(label_map_path, "r") as f:
                    label_map = json.load(f)
        elif current_model and current_tokenizer:
            model = current_model
            tokenizer = current_tokenizer
        else:
            return {"error": "请先训练模型或指定模型路径"}
        
        # 准备输入
        inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
        
        # 推理
        with torch.no_grad():
            outputs = model(**inputs)
        
        # 处理输出
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=1).numpy()[0]
        
        # 反转label_map用于显示
        reverse_label_map = {v: k for k, v in label_map.items()}
        
        # 返回预测结果和置信度
        results = {
            reverse_label_map[i]: float(probabilities[i]) 
            for i in range(len(probabilities))
        }
        
        # 按置信度排序
        sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
        
        return sorted_results
    except Exception as e:
        return {"error": f"预测失败: {str(e)}"}

# 构建界面
with gr.Blocks(title="AI训练工具体验版", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 🚀 领域专用AI训练工具(离线版)
    *使用本地模型进行训练和推理*
    """)
    
    with gr.Tab("数据准备"):
        with gr.Row():
            with gr.Column():
                # 文件上传组件
                data_input = gr.File(label="上传数据集", file_types=[".csv", ".xlsx"])
                preview_btn = gr.Button("预览数据")
                
                # 新增文本生成数据集组件
                gr.Markdown("### 或从文本生成数据集")
                gen_text = gr.Textbox(
                    label="输入文本数据",
                    placeholder="每行一个样本,格式如: '症状描述,诊断为疾病'",
                    lines=10,
                    value="""1. 患者主诉头痛、发热,体温38.5℃,诊断为感冒。
2. 心电图显示心律不齐,诊断为心脏病。
3. 血糖水平高于正常值,诊断为糖尿病。
4. 咳嗽伴有痰液,诊断为支气管炎。
5. 血压持续升高,诊断为高血压。"""
                )
                gen_format = gr.Dropdown(
                    label="输出格式",
                    choices=["csv", "json", "huggingface"],
                    value="csv"
                )
                gen_btn = gr.Button("生成数据集")
                
                # 数据配置
                gr.Markdown("### 数据配置")
                text_col = gr.Textbox(label="文本列名", placeholder="例如: text", value="text")
                label_col = gr.Textbox(label="标签列名", placeholder="例如: label", value="label")
                
                # 数据集转换
                gr.Markdown("### 数据集转换")
                output_format = gr.Dropdown(
                    label="输出格式",
                    choices=["csv", "json", "huggingface"],
                    value="csv"
                )
                convert_btn = gr.Button("转换数据集格式")
            
            with gr.Column():
                # 数据预览
                data_preview = gr.Dataframe(label="数据预览", interactive=False)
                # 状态显示
                gen_status = gr.Textbox(label="生成状态", interactive=False)
                convert_status = gr.Textbox(label="转换状态", interactive=False)
    
    with gr.Tab("模型训练"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 模型配置")
                model_selector = gr.Dropdown(
                    label="选择预训练模型",
                    choices=get_available_models(),
                    value=get_available_models()[0] if get_available_models() else None
                )
                
                gr.Markdown("### 训练参数")
                with gr.Row():
                    lr_slider = gr.Slider(1e-5, 1e-3, value=2e-5, label="学习率", step=1e-6)
                    batch_slider = gr.Slider(1, 32, value=8, step=1, label="批次大小")
                with gr.Row():
                    epoch_slider = gr.Slider(1, 10, value=3, step=1, label="训练轮数")
                    output_dir = gr.Textbox(label="输出目录", value="./trained_model")
                
                train_btn = gr.Button("开始训练", variant="primary")
            
            with gr.Column():
                status = gr.Textbox(label="训练状态", interactive=False)
                plot = gr.Plot(label="训练指标")
    
    with gr.Tab("模型推理"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 使用训练好的模型")
                trained_model_path = gr.Textbox(label="模型路径", placeholder="输入训练好的模型路径")
                test_input = gr.Textbox(label="输入测试文本", placeholder="输入要分类的文本...", lines=3)
                test_btn = gr.Button("测试模型", variant="primary")
            
            with gr.Column():
                test_output = gr.Label(label="预测结果")
                examples = gr.Examples(
                    examples=["这是一个正面的评论", "我不喜欢这个产品", "质量一般般"],
                    inputs=test_input
                )
    
    # 事件绑定
    preview_btn.click(
        fn=preview_data,
        inputs=data_input,
        outputs=[convert_status, data_preview]
    )
    
    gen_btn.click(
        fn=generate_dataset_from_text,
        inputs=[gen_text, gen_format],
        outputs=[gen_status, data_preview]
    )
    
    convert_btn.click(
        fn=convert_to_dataset,
        inputs=[data_input, text_col, label_col, output_format],
        outputs=[convert_status, data_preview]
    )
    
    train_btn.click(
        fn=train_model,
        inputs=[data_input, text_col, label_col, model_selector, 
               lr_slider, batch_slider, epoch_slider, output_dir],
        outputs=status
    ).then(
        fn=plot_metrics,
        outputs=plot
    )
    
    test_btn.click(
        fn=predict,
        inputs=[test_input, trained_model_path],
        outputs=test_output
    )

if __name__ == "__main__":
    # 创建模型目录如果不存在
    Path(model_dir).mkdir(exist_ok=True)
    app.launch(server_port=7860, share=False)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值