import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import random
import tkinter as tk
import jieba
import matplotlib.pyplot as plt
import os
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torch.cuda.amp import GradScaler, autocast
from nltk.translate.bleu_score import corpus_bleu
from rouge import Rouge
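# Special tokens shared by the vocabulary, the dataset tensors, and the seq2seq decoder.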
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"
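# word2index / index2word start with only the special tokens and grow via build_vocab().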
word2index = {PAD_TOKEN: 0, UNK_TOKEN: 1, SOS_TOKEN: 2, EOS_TOKEN: 3}
index2word = {0: PAD_TOKEN, 1: UNK_TOKEN, 2: SOS_TOKEN, 3: EOS_TOKEN}
def tokenize_chinese(sentence):
    # Segment a Chinese sentence into word-level tokens with jieba.
    return jieba.lcut(sentence)
def build_vocab(sentences):
    # Extend the global vocabulary with every new token found in `sentences`
    # and return the resulting vocabulary size.
    global word2index, index2word
    vocab_size = len(word2index)
    for sentence in sentences:
        for token in tokenize_chinese(sentence):
            if token not in word2index:
                word2index[token] = vocab_size
                index2word[vocab_size] = token
                vocab_size += 1
    return vocab_size
def sentence_to_tensor(sentence, max_length=50):
    # Encode a sentence as <SOS> ... <EOS>, truncated and padded to max_length.
    tokens = tokenize_chinese(sentence)[:max_length - 2]  # leave room for <SOS>/<EOS>
    indices = [word2index.get(token, word2index[UNK_TOKEN]) for token in tokens]
    indices = [word2index[SOS_TOKEN]] + indices + [word2index[EOS_TOKEN]]
    length = len(indices)  # true length before padding
    indices += [word2index[PAD_TOKEN]] * (max_length - length)
    return torch.tensor(indices, dtype=torch.long), length
def load_data(file_path):
    # Load question/answer records from a .jsonl file (one JSON object per line)
    # or from a single .json file containing a list of records.
    if file_path.endswith('.jsonl'):
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = [json.loads(line) for line in f if line.strip()]
    elif file_path.endswith('.json'):
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = json.load(f)
    else:
        raise ValueError("Unsupported file format. Please use .jsonl or .json.")
    questions = [line['question'] for line in lines]
    # Each record carries several reference answers; pick one at random per question.
    answers = [random.choice(line['human_answers'] + line['chatgpt_answers']) for line in lines]
    return questions, answers
def data_augmentation(sentence):
    # Lightly perturb a sentence: random insertion, deletion, swap, synonym
    # replacement, and rewriting, each applied with low probability.
    tokens = tokenize_chinese(sentence)
    # Random insertion of a vocabulary token (skips the 4 special tokens;
    # assumes build_vocab() has already added regular words).
    if random.random() < 0.1:
        insert_token = random.choice(list(word2index.keys())[4:])
        insert_index = random.randint(0, len(tokens))
        tokens.insert(insert_index, insert_token)
    # Random deletion of a single token.
    if random.random() < 0.1 and len(tokens) > 1:
        delete_index = random.randint(0, len(tokens) - 1)
        del tokens[delete_index]
    # Random swap of two tokens.
    if len(tokens) > 1 and random.random() < 0.1:
        index1, index2 = random.sample(range(len(tokens)), 2)
        tokens[index1], tokens[index2] = tokens[index2], tokens[index1]
    # Random synonym replacement on a small fraction of tokens.
    if random.random() < 0.1:
        for i in range(len(tokens)):
            if random.random() < 0.1:
                synonyms = get_synonyms(tokens[i])
                if synonyms:
                    tokens[i] = random.choice(synonyms)
    # Random whole-sentence rewriting.
    if random.random() < 0.1:
        tokens = rewrite_sentence(tokens)
    # Chinese text needs no separator when joining tokens back together.
    return ''.join(tokens)
def get_synonyms(word):
    # Placeholder: hook up a synonym source (e.g. a thesaurus) here; returns none for now.
    return []
def rewrite_sentence(tokens):
    # Placeholder: returns the tokens unchanged; plug in a paraphrasing step if needed.
    return tokens
class ChatDataset(Dataset):
    # Wraps parallel question/answer lists and returns padded index tensors
    # together with the true (unpadded) sequence lengths.
    def __init__(self, questions, answers):
        self.questions = questions
        self.answers = answers
    def __len__(self):
        return len(self.questions)
    def __getitem__(self, idx):
        input_tensor, input_length = sentence_to_tensor(self.questions[idx])
        target_tensor, target_length = sentence_to_tensor(self.answers[idx])
        return input_tensor, target_tensor, input_length, target_length
def collate_fn(batch):
    # Regroup a batch of dataset items into input/target tensors and length lists.
    inputs, targets, input_lengths, target_lengths = zip(*batch)
    inputs = nn.utils.rnn.