To further refine the code so that training can either continue from an existing model or train a new model from scratch and save it in its place, and to strengthen the data handling and the GUI, we made the following improvements:
1. Training-mode choice: add an option that lets the user choose between continuing to train an existing model and training a new model from scratch (see the sketch after this list).
2. Stronger data handling: add data-cleaning steps to ensure the data is valid and consistent.
3. GUI enhancements: add an option box for selecting the training mode (continue training or train from scratch) and improve the overall interaction experience.
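For improvement 1, the core loading logic can look like the minimal sketch below. Note that model_class, model_path, and continue_training are illustrative names only; the full listing that follows is the actual version:

import os
import torch

def load_or_create_model(model_class, model_path, device, continue_training):
    # Start from freshly initialized weights, then overwrite them with the
    # saved checkpoint when the user chose to continue training and a
    # checkpoint exists on disk; otherwise a new model is trained.
    model = model_class()
    if continue_training and os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
    return model.to(device)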
Full code:
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox
import logging
from difflib import SequenceMatcher
from datetime import datetime
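# Third-party dependencies (assumed install command): pip install torch transformers jsonlines
# tkinter ships with the standard CPython installers; on some Linux
# distributions it is packaged separately (e.g. python3-tk).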
# Project root directory
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
# Logging setup
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)
def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d/%H-%M-%S/羲和.txt'))
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # encoding='utf-8' keeps non-ASCII log messages readable on all platforms
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()
        ]
    )
setup_logging()
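# With the strftime format above, a run started e.g. on 2024-06-01 at
# 09:30:00 logs to logs/2024-06-01/09-30-00/羲和.txt (one directory per run).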
# Dataset class
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)
    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                # InvalidLineError is raised while the reader parses each
                # line, i.e. during iteration, so the try/except must wrap
                # next() rather than the loop body.
                lines = iter(reader)
                i = 0
                while True:
                    i += 1
                    try:
                        item = next(lines)
                    except StopIteration:
                        break
                    except jsonlines.InvalidLineError as e:
                        logging.warning(f"Skipping invalid line {i}: {e}")
                        continue
                    if self.validate_item(item):
                        data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    data = [item for item in json.load(f) if self.validate_item(item)]
                except json.JSONDecodeError as e:
                    logging.warning(f"Skipping invalid file {file_path}: {e}")
        return data
    def validate_item(self, item):
        required_keys = ['question', 'human_answers', 'chatgpt_answers']
        if all(key in item for key in required_keys):
            return True
        logging.warning(f"Skipping invalid item: missing required keys {required_keys}")
        return False
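    # (Sketch, not part of the original listing) An optional extra cleaning
    # step for improvement 2: strip whitespace and drop items whose question
    # or answer lists end up empty. Call it from load_data after
    # validate_item if stricter cleaning is wanted.
    def clean_item(self, item):
        item['question'] = item['question'].strip()
        item['human_answers'] = [a.strip() for a in item['human_answers'] if a and a.strip()]
        item['chatgpt_answers'] = [a.strip() for a in item['chatgpt_answers'] if a and a.strip()]
        return bool(item['question'] and item['human_answers'] and item['chatgpt_answers'])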
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]
        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding