现在的AI多数是通用的,未来会有更多专用的AI
把专业领域的书籍资料扔给AI学习,让它成为该领域的专家
这是我写这段代码的初衷。
这段代码实现了一个基于Gradio的领域专用AI训练工具的Web界面,主要功能是让用户能够上传数据集并训练文本分类模型。以下是代码的主要功能和结构分析:
主要功能
-
模型训练功能:
-
支持上传CSV/Excel格式的数据集
-
可选择不同的预训练模型(BERT、DistilBERT、RoBERTa)
-
可调整学习率、批次大小和训练轮数等超参数
-
自动跟踪和显示训练过程中的损失和准确率指标
-
-
可视化功能:
-
实时绘制训练损失和验证准确率曲线
-
双坐标轴显示不同量纲的指标
-
-
模型测试功能:
-
提供界面输入测试文本
-
显示预测结果(当前是模拟功能)
-
代码结构
-
导入库:
-
使用了Hugging Face的transformers库进行模型训练
-
使用Gradio构建Web界面
-
使用pandas和datasets处理数据
-
-
训练函数
train_model()
:-
数据预处理和标签编码
-
加载预训练模型和分词器
-
设置训练参数和评估指标
-
训练过程记录指标到全局变量
-
-
可视化函数
plot_metrics()
:-
使用matplotlib绘制训练曲线
-
支持双坐标轴显示不同指标
-
-
Gradio界面构建:
-
左侧面板:数据上传、参数配置
-
右侧面板:训练状态、指标图表、模型测试
-
使用说明
-
上传包含文本和标签列的CSV文件
-
指定文本列和标签列的名称
-
选择预训练模型和调整参数
-
点击"开始训练"按钮
-
训练完成后可以在下方测试模型
注意事项
-
当前测试功能是模拟的,需要替换为真实的模型预测代码
-
训练回调函数可能需要调整以正确记录指标
-
需要安装所有依赖库才能运行
-
界面默认运行在7860端口
这个工具适合初学者体验NLP模型训练的基本流程,可以通过修改代码来扩展更多功能。
在自然语言处理(NLP)领域,BERT、DistilBERT 和 RoBERTa 都是基于 Transformer 架构的预训练模型,但它们在模型结构、训练方式和性能上有所不同。以下是它们的核心区别:
1. BERT(BERT-base-uncased)
-
特点:
-
双向上下文理解:通过掩码语言建模(MLM)和下一句预测(NSP)任务预训练,能同时捕捉左右上下文信息。
-
模型规模:Base 版本有 12 层 Transformer,约 1.1 亿参数。
-
训练数据:BooksCorpus 和英文维基百科。
-
-
优势:
-
通用性强,适合大多数 NLP 任务(如文本分类、问答、NER)。
-
开源生态完善,文档和教程丰富。
-
-
缺点:
-
计算资源消耗较大(训练和推理速度较慢)。
-
NSP 任务在后继研究中被证明效果有限。
-
2. DistilBERT(DistilBERT-base-uncased)
-
特点:
-
轻量化版 BERT:通过知识蒸馏(Knowledge Distillation)技术,用 BERT 作为教师模型训练的小型模型。
-
模型规模:6 层 Transformer,参数量减少 40%(约 6600 万),但保留 97% 的 BERT 性能。
-
训练任务:仅保留 MLM,移除了 NSP。
-
-
优势:
-
推理速度快 60%,内存占用更低,适合资源受限场景(如移动端、边缘设备)。
-
在大多数任务上性能接近 BERT。
-
-
缺点:
-
极复杂任务(如需要深层语义理解)上可能略逊于原始 BERT。
-
蒸馏过程可能丢失部分细粒度语义信息。
-
3. RoBERTa(RoBERTa-base)
-
特点:
-
BERT 的优化版:由 Facebook 提出,改进了 BERT 的训练策略。
-
关键改进:
-
移除 NSP 任务,仅用 MLM。
-
更大的批次大小(8K vs BERT 的 256)和更多数据(160GB vs BERT 的 16GB)。
-
动态掩码(Dynamic Masking)代替静态掩码。
-
-
模型规模:与 BERT-base 相同(12 层),但训练更充分。
-
-
优势:
-
性能更强,在 GLUE、SQuAD 等基准上超越 BERT。
-
对长文本和复杂任务表现更好。
-
-
缺点:
-
训练成本极高(需要大量计算资源)。
-
推理速度与 BERT 相近,仍比 DistilBERT 慢。
-
对比总结
特性 | BERT | DistilBERT | RoBERTa |
---|---|---|---|
参数量 | ~110M | ~66M | ~110M |
训练数据量 | 16GB | 同 BERT | 160GB |
训练任务 | MLM + NSP | MLM | MLM(动态掩码) |
推理速度 | 中等 | 快 | 中等 |
典型应用场景 | 通用 NLP | 资源受限环境 | 高性能需求任务 |
选择建议
-
优先 RoBERTa:如果追求最高准确率且资源充足。
-
优先 DistilBERT:如果需要快速部署或资源有限(如 API 服务)。
-
选择 BERT:如果需平衡性能和兼容性,或依赖现有 BERT 生态工具。
在你的训练工具中,用户可以根据任务需求(速度 vs 精度)和硬件条件灵活选择模型。
以下是完整代码,由于是离线版,需要在根目录创建model文件夹然后将模型下载放入,下载链接:模型http://通过网盘分享的文件:model.zip 链接: https://pan.baidu.com/s/1YxzD9T4jd545FNzPwte_Sg?pwd=jfpj 提取码: jfpj --来自百度网盘超级会员v8的分享
import gradio as gr
import pandas as pd
import numpy as np
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer,
TrainerCallback,
pipeline
)
import datasets
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import os
import json
from pathlib import Path
from io import StringIO
# 全局变量存储训练状态和模型
training_history = {"loss": [], "accuracy": []}
current_model = None
current_tokenizer = None
label_map = {}
model_dir = "./model" # 模型存放目录
class LoggingCallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs):
if logs:
if "loss" in logs and "epoch" in logs:
training_history["loss"].append((logs["epoch"], logs["loss"]))
if "eval_accuracy" in logs and "epoch" in logs:
training_history["accuracy"].append((logs["epoch"], logs["eval_accuracy"]))
def get_available_models():
"""获取本地可用的模型列表"""
models = []
for model_name in ["bert-base-uncased", "distilbert-base-uncased", "roberta-base", "bert-base-chinese"]:
model_path = os.path.join(model_dir, model_name)
if os.path.exists(model_path):
models.append(model_name)
return models
def load_data(data_path, text_column, label_column):
try:
# 读取数据文件
if data_path.endswith('.csv'):
df = pd.read_csv(data_path)
elif data_path.endswith('.xlsx'):
df = pd.read_excel(data_path)
else:
return None, "不支持的文件格式,请上传CSV或Excel文件"
# 检查列名是否存在
if text_column not in df.columns or label_column not in df.columns:
return None, f"数据集中找不到指定的列: '{text_column}' 或 '{label_column}'"
# 检查数据是否为空
if df.empty:
return None, "数据集为空"
# 创建标签映射
global label_map
label_list = sorted(df[label_column].unique().tolist())
label_map = {label: i for i, label in enumerate(label_list)}
# 转换数据集格式
dataset = datasets.Dataset.from_pandas(df)
dataset = dataset.map(lambda x: {"label": label_map[x[label_column]]})
return dataset, None
except Exception as e:
return None, f"数据加载失败: {str(e)}"
def generate_dataset_from_text(input_text, output_format="csv"):
"""从文本生成数据集"""
try:
# 处理输入文本
data = []
for line in input_text.strip().split("\n"):
line = line.strip()
if line and "诊断为" in line:
# 分割文本和标签
parts = line.split(",诊断为")
if len(parts) == 2:
text = parts[0].strip()
label = parts[1].replace("。", "").strip()
data.append({"text": text, "label": label})
if not data:
return "未找到有效数据,请确保每行包含'诊断为'分隔符", None
# 创建DataFrame
df = pd.DataFrame(data)
# 创建输出目录
output_dir = "./generated_dataset"
os.makedirs(output_dir, exist_ok=True)
# 保存标签映射
global label_map
label_list = sorted(df["label"].unique().tolist())
label_map = {label: i for i, label in enumerate(label_list)}
with open(os.path.join(output_dir, "label_map.json"), "w") as f:
json.dump(label_map, f)
# 根据格式保存
if output_format == "csv":
output_path = os.path.join(output_dir, "generated_dataset.csv")
df.to_csv(output_path, index=False, encoding="utf-8")
elif output_format == "json":
output_path = os.path.join(output_dir, "generated_dataset.json")
df.to_json(output_path, orient="records", force_ascii=False)
elif output_format == "huggingface":
output_path = output_dir
dataset = datasets.Dataset.from_pandas(df)
dataset.save_to_disk(output_path)
else:
return "不支持的输出格式", None
# 返回数据集预览
preview_df = df.head(5)
return f"数据集已生成并保存为{output_format.upper()}格式到: {output_path}", preview_df
except Exception as e:
return f"生成数据集失败: {str(e)}", None
def convert_to_dataset(data_path, text_column, label_column, output_format="csv"):
"""将数据转换为指定格式的数据集"""
try:
dataset, error_msg = load_data(data_path, text_column, label_column)
if error_msg:
return error_msg, None
output_dir = "./converted_dataset"
os.makedirs(output_dir, exist_ok=True)
# 保存标签映射
with open(os.path.join(output_dir, "label_map.json"), "w") as f:
json.dump(label_map, f)
# 转换为指定格式
if output_format == "csv":
output_path = os.path.join(output_dir, "dataset.csv")
dataset.to_csv(output_path)
elif output_format == "json":
output_path = os.path.join(output_dir, "dataset.json")
dataset.to_json(output_path)
elif output_format == "huggingface":
output_path = output_dir
dataset.save_to_disk(output_path)
else:
return "不支持的输出格式", None
# 返回数据集预览
preview_df = dataset.to_pandas().head(5)
return f"数据集已成功转换为{output_format.upper()}格式,保存到: {output_path}", preview_df
except Exception as e:
return f"数据集转换失败: {str(e)}", None
def preview_data(data_path):
try:
if data_path.endswith('.csv'):
df = pd.read_csv(data_path, nrows=5)
elif data_path.endswith('.xlsx'):
df = pd.read_excel(data_path, nrows=5)
else:
return "不支持的文件格式", None
return None, df
except Exception as e:
return f"数据预览失败: {str(e)}", None
def train_model(
data_path,
text_column,
label_column,
model_name,
learning_rate,
batch_size,
epochs,
output_dir
):
global current_model, current_tokenizer, training_history, label_map
try:
# 加载数据
dataset, error_msg = load_data(data_path, text_column, label_column)
if error_msg:
return error_msg
# 检查模型是否存在
model_path = os.path.join(model_dir, model_name)
if not os.path.exists(model_path):
return f"找不到本地模型: {model_name},请确保模型已下载到{model_dir}目录"
# 加载模型和分词器
try:
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(
model_path,
num_labels=len(label_map),
ignore_mismatched_sizes=True # 忽略尺寸不匹配
)
except Exception as e:
return f"模型加载失败: {str(e)}"
# 数据预处理
def tokenize_function(examples):
return tokenizer(examples[text_column], truncation=True, padding="max_length", max_length=128)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_dataset = tokenized_dataset.train_test_split(test_size=0.2, seed=42)
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
# 保存标签映射
with open(os.path.join(output_dir, "label_map.json"), "w") as f:
json.dump(label_map, f)
# 训练配置
training_args = TrainingArguments(
output_dir=output_dir,
evaluation_strategy="epoch",
learning_rate=float(learning_rate),
per_device_train_batch_size=int(batch_size),
per_device_eval_batch_size=int(batch_size),
num_train_epochs=int(epochs),
weight_decay=0.01,
save_strategy="epoch",
load_best_model_at_end=True,
logging_steps=10,
report_to="none"
)
# 定义评估指标
def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
return {"accuracy": accuracy_score(labels, predictions)}
# 重置训练历史
training_history = {"loss": [], "accuracy": []}
# 训练
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["test"],
compute_metrics=compute_metrics,
)
trainer.add_callback(LoggingCallback())
trainer.train()
# 保存模型
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
# 更新当前模型
current_model = model
current_tokenizer = tokenizer
return f"训练完成!模型已保存到 {output_dir}"
except Exception as e:
return f"训练过程中出错: {str(e)}"
def plot_metrics():
if not training_history["loss"]:
return None
fig, ax = plt.subplots(figsize=(10, 5))
# 准备数据
epochs = [x[0] for x in training_history["loss"]]
loss_values = [x[1] for x in training_history["loss"]]
ax.plot(epochs, loss_values, label="训练损失", marker='o')
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.grid(True)
if training_history["accuracy"]:
acc_epochs = [x[0] for x in training_history["accuracy"]]
acc_values = [x[1] for x in training_history["accuracy"]]
ax2 = ax.twinx()
ax2.plot(acc_epochs, acc_values, color="red", label="验证准确率", marker='x')
ax2.set_ylabel("Accuracy")
ax2.set_ylim(0, 1)
# 合并图例
lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines + lines2, labels + labels2, loc="upper center")
else:
ax.legend()
plt.title("训练指标")
plt.tight_layout()
return fig
def predict(text, model_path=None):
global current_model, current_tokenizer, label_map
try:
# 如果没有传入模型路径,使用当前训练的模型
if model_path:
if not os.path.exists(model_path):
return {"error": "模型路径不存在"}
# 加载模型和分词器
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
# 加载标签映射
label_map_path = os.path.join(model_path, "label_map.json")
if os.path.exists(label_map_path):
with open(label_map_path, "r") as f:
label_map = json.load(f)
elif current_model and current_tokenizer:
model = current_model
tokenizer = current_tokenizer
else:
return {"error": "请先训练模型或指定模型路径"}
# 准备输入
inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
# 推理
with torch.no_grad():
outputs = model(**inputs)
# 处理输出
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1).numpy()[0]
# 反转label_map用于显示
reverse_label_map = {v: k for k, v in label_map.items()}
# 返回预测结果和置信度
results = {
reverse_label_map[i]: float(probabilities[i])
for i in range(len(probabilities))
}
# 按置信度排序
sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
return sorted_results
except Exception as e:
return {"error": f"预测失败: {str(e)}"}
# 构建界面
with gr.Blocks(title="AI训练工具体验版", theme=gr.themes.Soft()) as app:
gr.Markdown("""
# 🚀 领域专用AI训练工具(离线版)
*使用本地模型进行训练和推理*
""")
with gr.Tab("数据准备"):
with gr.Row():
with gr.Column():
# 文件上传组件
data_input = gr.File(label="上传数据集", file_types=[".csv", ".xlsx"])
preview_btn = gr.Button("预览数据")
# 新增文本生成数据集组件
gr.Markdown("### 或从文本生成数据集")
gen_text = gr.Textbox(
label="输入文本数据",
placeholder="每行一个样本,格式如: '症状描述,诊断为疾病'",
lines=10,
value="""1. 患者主诉头痛、发热,体温38.5℃,诊断为感冒。
2. 心电图显示心律不齐,诊断为心脏病。
3. 血糖水平高于正常值,诊断为糖尿病。
4. 咳嗽伴有痰液,诊断为支气管炎。
5. 血压持续升高,诊断为高血压。"""
)
gen_format = gr.Dropdown(
label="输出格式",
choices=["csv", "json", "huggingface"],
value="csv"
)
gen_btn = gr.Button("生成数据集")
# 数据配置
gr.Markdown("### 数据配置")
text_col = gr.Textbox(label="文本列名", placeholder="例如: text", value="text")
label_col = gr.Textbox(label="标签列名", placeholder="例如: label", value="label")
# 数据集转换
gr.Markdown("### 数据集转换")
output_format = gr.Dropdown(
label="输出格式",
choices=["csv", "json", "huggingface"],
value="csv"
)
convert_btn = gr.Button("转换数据集格式")
with gr.Column():
# 数据预览
data_preview = gr.Dataframe(label="数据预览", interactive=False)
# 状态显示
gen_status = gr.Textbox(label="生成状态", interactive=False)
convert_status = gr.Textbox(label="转换状态", interactive=False)
with gr.Tab("模型训练"):
with gr.Row():
with gr.Column():
gr.Markdown("### 模型配置")
model_selector = gr.Dropdown(
label="选择预训练模型",
choices=get_available_models(),
value=get_available_models()[0] if get_available_models() else None
)
gr.Markdown("### 训练参数")
with gr.Row():
lr_slider = gr.Slider(1e-5, 1e-3, value=2e-5, label="学习率", step=1e-6)
batch_slider = gr.Slider(1, 32, value=8, step=1, label="批次大小")
with gr.Row():
epoch_slider = gr.Slider(1, 10, value=3, step=1, label="训练轮数")
output_dir = gr.Textbox(label="输出目录", value="./trained_model")
train_btn = gr.Button("开始训练", variant="primary")
with gr.Column():
status = gr.Textbox(label="训练状态", interactive=False)
plot = gr.Plot(label="训练指标")
with gr.Tab("模型推理"):
with gr.Row():
with gr.Column():
gr.Markdown("### 使用训练好的模型")
trained_model_path = gr.Textbox(label="模型路径", placeholder="输入训练好的模型路径")
test_input = gr.Textbox(label="输入测试文本", placeholder="输入要分类的文本...", lines=3)
test_btn = gr.Button("测试模型", variant="primary")
with gr.Column():
test_output = gr.Label(label="预测结果")
examples = gr.Examples(
examples=["这是一个正面的评论", "我不喜欢这个产品", "质量一般般"],
inputs=test_input
)
# 事件绑定
preview_btn.click(
fn=preview_data,
inputs=data_input,
outputs=[convert_status, data_preview]
)
gen_btn.click(
fn=generate_dataset_from_text,
inputs=[gen_text, gen_format],
outputs=[gen_status, data_preview]
)
convert_btn.click(
fn=convert_to_dataset,
inputs=[data_input, text_col, label_col, output_format],
outputs=[convert_status, data_preview]
)
train_btn.click(
fn=train_model,
inputs=[data_input, text_col, label_col, model_selector,
lr_slider, batch_slider, epoch_slider, output_dir],
outputs=status
).then(
fn=plot_metrics,
outputs=plot
)
test_btn.click(
fn=predict,
inputs=[test_input, trained_model_path],
outputs=test_output
)
if __name__ == "__main__":
# 创建模型目录如果不存在
Path(model_dir).mkdir(exist_ok=True)
app.launch(server_port=7860, share=False)