A Chinese Question-Answering System Based on a Transformer Model

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import random
import tkinter as tk
import jieba
import matplotlib.pyplot as plt
import os
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torch.cuda.amp import GradScaler, autocast
from nltk.translate.bleu_score import corpus_bleu
from rouge import Rouge

# Special tokens
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"

# Chinese vocabulary and index mappings
word2index = {PAD_TOKEN: 0, UNK_TOKEN: 1, SOS_TOKEN: 2, EOS_TOKEN: 3}
index2word = {0: PAD_TOKEN, 1: UNK_TOKEN, 2: SOS_TOKEN, 3: EOS_TOKEN}

# Chinese word segmentation with jieba
def tokenize_chinese(sentence):
    tokens = jieba.lcut(sentence)
    return tokens
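
For a quick sanity check, jieba splits a sentence into word-level tokens rather than single characters (the exact segmentation depends on jieba's dictionary):

# Illustrative usage; the exact segmentation may vary with jieba's dictionary
print(tokenize_chinese("今天天气怎么样"))  # e.g. ['今天', '天气', '怎么样']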

# Build the vocabulary
def build_vocab(sentences):
    global word2index, index2word
    vocab_size = len(word2index)
    for sentence in sentences:
        for token in tokenize_chinese(sentence):
            if token not in word2index:
                word2index[token] = vocab_size
                index2word[vocab_size] = token
                vocab_size += 1
    return vocab_size

# Convert a sentence into a fixed-length tensor of token indices
def sentence_to_tensor(sentence, max_length=50):
    tokens = tokenize_chinese(sentence)
    indices = [word2index.get(token, word2index[UNK_TOKEN]) for token in tokens]
    # Add SOS/EOS, truncating so the sequence never exceeds max_length
    indices = [word2index[SOS_TOKEN]] + indices[:max_length - 2] + [word2index[EOS_TOKEN]]
    length = len(indices)  # true length before padding
    indices += [word2index[PAD_TOKEN]] * (max_length - length)
    return torch.tensor(indices, dtype=torch.long), length
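
A brief usage sketch (illustrative only): once the vocabulary has been built, a sentence maps to a length-50 index tensor plus its true, unpadded length.

# Illustrative usage; assumes build_vocab has already seen the sentence's tokens
build_vocab(["今天天气怎么样"])
tensor, length = sentence_to_tensor("今天天气怎么样")
print(tensor.shape, length)  # torch.Size([50]), length = SOS + tokens + EOS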

# Load data from a .jsonl or .json file
def load_data(file_path):
    if file_path.endswith('.jsonl'):
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = [json.loads(line) for line in f if line.strip()]
    elif file_path.endswith('.json'):
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = json.load(f)
    else:
        raise ValueError("Unsupported file format. Please use .jsonl or .json.")
    
    questions = [line['question'] for line in lines]
    answers = [random.choice(line['human_answers'] + line['chatgpt_answers']) for line in lines]
    return questions, answers
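
The loader assumes each record carries a "question" field plus "human_answers" and "chatgpt_answers" lists, and one answer is sampled at random per question. An illustrative .jsonl line (not taken from the original dataset) might look like:

{"question": "什么是Transformer模型?", "human_answers": ["一种基于自注意力机制的神经网络结构。"], "chatgpt_answers": ["Transformer是一种使用自注意力机制处理序列数据的模型。"]}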

# Data augmentation
def data_augmentation(sentence):
    tokens = tokenize_chinese(sentence)
    # Random insertion
    if random.random() < 0.1:
        insert_token = random.choice(list(word2index.keys())[4:])  # skip the special tokens
        insert_index = random.randint(0, len(tokens))
        tokens.insert(insert_index, insert_token)
    # Random deletion
    if random.random() < 0.1 and len(tokens) > 1:
        delete_index = random.randint(0, len(tokens) - 1)
        del tokens[delete_index]
    # Random swap
    if len(tokens) > 1 and random.random() < 0.1:
        index1, index2 = random.sample(range(len(tokens)), 2)
        tokens[index1], tokens[index2] = tokens[index2], tokens[index1]
    # Synonym replacement
    if random.random() < 0.1:
        for i in range(len(tokens)):
            if random.random() < 0.1:
                synonyms = get_synonyms(tokens[i])
                if synonyms:
                    tokens[i] = random.choice(synonyms)
    # Semantics-preserving sentence rewriting
    if random.random() < 0.1:
        tokens = rewrite_sentence(tokens)
    augmented_sentence = ''.join(tokens)
    return augmented_sentence

# Look up synonyms
def get_synonyms(word):
    # An external library or API could be plugged in here to look up synonyms
    return []
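
One possible way to fill in this stub (a sketch, not part of the original code) is the open-source Chinese `synonyms` package, whose nearby() call returns candidate words together with similarity scores; the function name and threshold below are illustrative.

# Hypothetical alternative using the third-party `synonyms` package (pip install synonyms)
def get_synonyms_via_package(word, threshold=0.7):
    import synonyms  # imported lazily so the rest of the script runs without it
    words, scores = synonyms.nearby(word)
    return [w for w, s in zip(words, scores) if w != word and s >= threshold]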

# Semantics-preserving sentence rewriting
def rewrite_sentence(tokens):
    # An external library or API could be plugged in here for paraphrasing
    return tokens

# Dataset definition
class ChatDataset(Dataset):
    def __init__(self, questions, answers):
        self.questions = questions
        self.answers = answers

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        input_tensor, input_length = sentence_to_tensor(self.questions[idx])
        target_tensor, target_length = sentence_to_tensor(self.answers[idx])
        return input_tensor, target_tensor, input_length, target_length

# Custom collate function
def collate_fn(batch):
    inputs, targets, input_lengths, target_lengths = zip(*batch)
    # Stack the (already padded) sequences into batch tensors; PAD index is 0
    inputs = nn.utils.rnn.pad_sequence(inputs, batch_first=True, padding_value=word2index[PAD_TOKEN])
    targets = nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=word2index[PAD_TOKEN])
    return inputs, targets, torch.tensor(input_lengths), torch.tensor(target_lengths)
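
Putting the pieces together, the dataset and collate function can feed a DataLoader; the file path and batch size below are placeholders, not values from the original article.

# Illustrative wiring of the pieces above (path and hyperparameters are placeholders)
if __name__ == "__main__":
    questions, answers = load_data("train.jsonl")
    build_vocab(questions + answers)
    dataset = ChatDataset(questions, answers)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    for inputs, targets, input_lengths, target_lengths in loader:
        print(inputs.shape, targets.shape)  # (batch_size, 50) index tensors
        break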