既然 xihua_model.pth 已经训练完成,现在只需要编写使用该模型的聊天机器人“羲和”的代码。以下是完整的代码,包括加载模型、创建 GUI 界面、处理用户输入,以及根据模型输出从训练数据中查找并显示回答的功能。
聊天机器人羲和的代码
python
import os
import json
import jsonlines
import torch
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox
import logging
# Configure root logging: INFO level, "timestamp - level - message" records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Absolute path of the directory containing this script; the model and data
# paths used below are resolved relative to it.
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
# 数据集类(用于加载数据)
class XihuaDataset:
    """Records loaded from a ``.jsonl`` or ``.json`` file and held in memory.

    Attributes:
        data: Whatever ``load_data`` returned for the given path.
    """

    def __init__(self, file_path):
        """Load *file_path* immediately and store the result in ``self.data``."""
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        """Read records from *file_path*, choosing the parser by extension.

        Args:
            file_path: Path ending in ``.jsonl`` (one JSON value per line) or
                ``.json`` (a single JSON document).

        Returns:
            A list with one entry per line for ``.jsonl``; the decoded
            document for ``.json``; an empty list for any other extension
            (the file is not opened in that case).
        """
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                return list(reader)
        if file_path.endswith('.json'):
            # Decode explicitly as UTF-8: the platform default (e.g. GBK on a
            # Chinese-locale Windows machine) would mis-decode Chinese text.
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        return []
# 模型定义
class XihuaModel(torch.nn.Module):
    """BERT encoder followed by a single-output linear layer.

    Args:
        pretrained_model_name: Name or local directory passed to
            ``BertModel.from_pretrained``.
    """

    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        # One scalar per input; sized from the encoder's hidden size.
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return the linear layer's output for the encoder's pooled output.

        Args:
            input_ids: Token ids forwarded to the BERT encoder.
            attention_mask: Attention mask forwarded to the BERT encoder.

        Returns:
            The tensor produced by applying ``self.classifier`` to
            ``pooler_output`` (last dimension is 1).
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(encoded.pooler_output)
# GUI界面
class XihuaChatbotGUI:
    """Tkinter window that answers questions by model-guided dataset lookup.

    On construction it loads the tokenizer, the ``XihuaModel`` weights from
    ``models/xihua_model.pth`` and the records in ``data/train_data.jsonl``
    (both relative to ``PROJECT_ROOT``), then builds the widgets.
    """

    # Local directory of the pretrained BERT model, shared by tokenizer and model.
    BERT_PATH = 'F:/models/bert-base-chinese'
    # Text shown when no usable answer exists for the question.
    NOT_FOUND = "未找到具体答案"

    def __init__(self, root):
        """Load tokenizer, model and data, then build the UI inside *root*."""
        self.root = root
        self.root.title("羲和聊天机器人")
        self.tokenizer = BertTokenizer.from_pretrained(self.BERT_PATH)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name=self.BERT_PATH).to(self.device)
        weights_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        self.model.load_state_dict(torch.load(weights_path, map_location=self.device))
        self.model.eval()
        # The records are kept in memory because answers are looked up in them.
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))
        self.create_widgets()

    def create_widgets(self):
        """Create and pack the question entry, the button and the answer box."""
        self.question_label = tk.Label(self.root, text="问题:")
        self.question_label.pack()
        self.question_entry = tk.Entry(self.root, width=50)
        self.question_entry.pack()
        self.answer_button = tk.Button(self.root, text="获取回答", command=self.get_answer)
        self.answer_button.pack()
        self.answer_label = tk.Label(self.root, text="回答:")
        self.answer_label.pack()
        self.answer_text = tk.Text(self.root, height=10, width=50)
        self.answer_text.pack()

    def get_answer(self):
        """Button callback: pick an answer field with the model and display it.

        The model's scalar output selects ``human_answers`` when positive and
        ``chatgpt_answers`` otherwise; the answer itself comes from the
        loaded records via ``get_specific_answer``.
        """
        question = self.question_entry.get().strip()
        if not question:
            # Nothing to classify or look up; skip the model call entirely.
            self._show_answer("请输入问题")
            return
        encoded = self.tokenizer(
            question,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=128,
        )
        with torch.no_grad():
            input_ids = encoded['input_ids'].to(self.device)
            attention_mask = encoded['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)
        answer_type = "human_answers" if logits.item() > 0 else "chatgpt_answers"
        self._show_answer(self.get_specific_answer(question, answer_type))

    def _show_answer(self, text):
        """Replace the contents of the answer box with *text*."""
        # Text widget indices are "line.column" strings; "1.0" is the start.
        self.answer_text.delete("1.0", tk.END)
        self.answer_text.insert(tk.END, text)

    def get_specific_answer(self, question, answer_type):
        """Return the first stored answer of *answer_type* for *question*.

        Args:
            question: Question text; matched exactly against each record's
                ``question`` field.
            answer_type: Record key holding the answers, e.g.
                ``"human_answers"`` or ``"chatgpt_answers"``.

        Returns:
            The first element of the matching record's answer list, or the
            not-found message when no record matches or the matching records
            have no (or an empty) answer list for that key.
        """
        for item in self.data:
            if item.get('question') != question:
                continue
            answers = item.get(answer_type)
            if answers:
                return answers[0]
        return self.NOT_FOUND

    def load_data(self, file_path):
        """Read records from *file_path*, choosing the parser by extension.

        Returns:
            A list with one entry per line for ``.jsonl``; the decoded
            document for ``.json``; an empty list for any other extension.
        """
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                return list(reader)
        if file_path.endswith('.json'):
            # Decode explicitly as UTF-8: the platform default (e.g. GBK on a
            # Chinese-locale Windows machine) would mis-decode Chinese text.
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        return []
# 主函数
def main():
    """Create the Tk root window, build the chatbot GUI and run the event loop.

    If the GUI cannot be constructed (for example the model weights or the
    data file cannot be loaded), the error is logged with its traceback, an
    error dialog is shown, and the process exits with status 1.
    """
    root = tk.Tk()
    try:
        XihuaChatbotGUI(root)
    except Exception:
        # Top-level boundary: without this, a double-click launch fails with
        # a traceback the user never sees.
        logging.exception("Failed to start the chatbot GUI")
        messagebox.showerror("羲和聊天机器人", "启动失败,请检查模型和数据文件路径。")
        root.destroy()
        raise SystemExit(1)
    root.mainloop()


if __name__ == "__main__":
    main()
说明
数据集类:XihuaDataset 类用于从 .jsonl 或 .json 文件加载训练数据集。注意:GUI 目前通过自身的 load_data 方法加载数据,并未直接使用该类。
模型定义:XihuaModel 类定义了模型的结构。
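GUI界面:XihuaChatbotGUI 类创建了聊天机器人的图形用户界面,并实现了获取回答的功能:模型对输入的问题输出一个数值,大于 0 时选用 human_answers,否则选用 chatgpt_answers;随后在训练数据中按问题文本精确匹配,返回对应字段的第一条答案,未匹配到时显示“未找到具体答案”。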
主函数:启动GUI界面。
运行步骤
确保项目目录结构:
code
project_root/
│
├── data/
│ └── train_data.jsonl
│
├── models/
│ └── xihua_model.pth
│
└── xihua_chatbot.py
运行聊天机器人:
sh
python xihua_chatbot.py
注意事项
BERT预训练模型路径:确保 F:/models/bert-base-chinese 路径正确,并且包含 BERT 预训练模型文件。该路径在代码中是写死的,如果模型存放在其他位置,需要同步修改代码中的路径。
训练数据路径:确保 data/train_data.jsonl 文件存在并且格式正确:每行是一个 JSON 对象,包含 question 字段,以及 human_answers 和 chatgpt_answers 两个答案列表。
通过以上步骤,你应该能够成功运行聊天机器人羲和,并使用已经训练好的模型为训练数据中已有的问题选择并显示回答。