聊天机器人羲和的代码01

已经完成了 xihua_model.pth 的训练,我们现在只需要编写一个使用该模型的聊天机器人羲和的代码。以下是完整的代码,包括加载模型、创建GUI界面以及处理用户输入和生成回答的功能。

聊天机器人羲和的代码
python

import os
import json
import jsonlines
import torch
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox
import logging

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 数据集类(用于加载数据)
class XihuaDataset:
    def __init__(self, file_path):
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for item in reader:
                    data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                data = json.load(f)
        return data

# 模型定义
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# GUI界面
class XihuaChatbotGUI:
    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")

        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device))
        self.model.eval()

        # 加载训练数据集以便在获取答案时使用
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

        self.create_widgets()

    def create_widgets(self):
        self.question_label = tk.Label(self.root, text="问题:")
        self.question_label.pack()

        self.question_entry = tk.Entry(self.root, width=50)
        self.question_entry.pack()

        self.answer_button = tk.Button(self.root, text="获取回答", command=self.get_answer)
        self.answer_button.pack()

        self.answer_label = tk.Label(self.root, text="回答:")
        self.answer_label.pack()

        self.answer_text = tk.Text(self.root, height=10, width=50)
        self.answer_text.pack()

    def get_answer(self):
        question = self.question_entry.get()
        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)
        
        if logits.item() > 0:
            answer_type = "human_answers"
        else:
            answer_type = "chatgpt_answers"

        specific_answer = self.get_specific_answer(question, answer_type)

        self.answer_text.delete(1.0, tk.END)
        self.answer_text.insert(tk.END, specific_answer)

    def get_specific_answer(self, question, answer_type):
        # 从数据集中查找具体的答案
        for item in self.data:
            if item['question'] == question:
                return item[answer_type][0]
        return "未找到具体答案"

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for item in reader:
                    data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                data = json.load(f)
        return data

# 主函数
if __name__ == "__main__":
    # 启动GUI
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    root.mainloop()

说明

数据集类:XihuaDataset 类用于加载训练数据集。
模型定义:XihuaModel 类定义了模型的结构。
GUI界面:XihuaChatbotGUI 类创建了聊天机器人的图形用户界面,并实现了获取回答的功能。
主函数:启动GUI界面。
运行步骤
确保项目目录结构:

code

project_root/
│
├── data/
│   └── train_data.jsonl
│
├── models/
│   └── xihua_model.pth
│
└── xihua_chatbot.py

运行聊天机器人:

sh

python xihua_chatbot.py
注意事项

BERT预训练模型路径:确保 F:/models/bert-base-chinese 路径正确,并且包含 BERT 预训练模型文件。
训练数据路径:确保 data/train_data.jsonl 文件存在并且格式正确。
通过以上步骤,你应该能够成功运行聊天机器人羲和,并使用已经训练好的模型来生成回答。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yehaiwz

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值