Chinese Couplet Generator Based on Transformer


Introduction

This project is a Transformer-based Chinese couplet generator: the model is built with PyTorch and the Web UI with Gradio.

Dataset: https://www.kaggle.com/datasets/marquis03/chinese-couplets-dataset

GitHub repository: https://github.com/Marquis03/Chinese-Couplets-Generator-based-on-Transformer

Gitee repository: https://gitee.com/marquis03/Chinese-Couplets-Generator-based-on-Transformer

Project Structure

.
├── config
│   ├── __init__.py
│   └── config.py
├── data
│   ├── fixed_couplets_in.txt
│   └── fixed_couplets_out.txt
├── dataset
│   ├── __init__.py
│   └── dataset.py
├── img
│   ├── history.png
│   ├── lr_schedule.png
│   └── webui.gif
├── model
│   ├── __init__.py
│   └── model.py
├── trained
│   ├── vocab.pkl
│   └── CoupletsTransformer_best.pth
├── utils
│   ├── __init__.py
│   └── EarlyStopping.py
├── LICENSE
├── README.md
├── requirements.txt
├── train.py
└── webui.py

Deployment

Clone the Project

git clone https://github.com/Marquis03/Chinese-Couplets-Generator-based-on-Transformer.git
cd Chinese-Couplets-Generator-based-on-Transformer

Install Dependencies

pip install -r requirements.txt

Train the Model

python train.py

Kaggle Notebook: https://www.kaggle.com/code/marquis03/chinese-couplets-generator-based-on-transformer

Start the Web UI

python webui.py

Demo

Web UI

(Animated demo of the Web UI: img/webui.gif)

Learning Rate Schedule

(Figure: learning-rate schedule, img/lr_schedule.png)
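
The curve is produced by WarmupExponentialLR in train.py: the learning rate rises linearly from lr_min to lr_max over the first warm_up_iter steps (warmup_ratio = 0.12 of all iterations by default), then decays exponentially toward lr_min by the final iteration T_max. As a sketch in formula form, with $\eta_{\max} = 10^{-3}$ and $\eta_{\min} = 10^{-4}$ taken from config.py:

$$
\eta(t) =
\begin{cases}
\eta_{\min} + (\eta_{\max} - \eta_{\min})\,\dfrac{t}{t_{\text{warm}}}, & t < t_{\text{warm}} \\[4pt]
\eta_{\max}\,\gamma^{\,t - t_{\text{warm}}}, & t \ge t_{\text{warm}}
\end{cases}
\qquad
\gamma = \left(\frac{\eta_{\min}}{\eta_{\max}}\right)^{1/(T_{\max} - t_{\text{warm}})}
$$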

Training History

(Figure: training history, img/history.png)

Code

Configuration (Config)

This part defines the project's configuration: global, path, model, training, and logging parameters.

The corresponding project file is config/config.py.

import os
import sys
import time
import torch
from loguru import logger


class Config:
    def __init__(self):
        # global
        self.seed = 0
        self.cuDNN = True
        self.debug = False
        self.num_workers = 0
        self.str_time = time.strftime("%Y-%m-%dT%H%M%S", time.localtime(time.time()))
        # path
        self.project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.dataset_dir = os.path.join(self.project_dir, "data")
        self.in_path = os.path.join(self.dataset_dir, "fixed_couplets_in.txt")
        self.out_path = os.path.join(self.dataset_dir, "fixed_couplets_out.txt")
        self.log_dir = os.path.join(self.project_dir, "logs")
        self.save_dir = os.path.join(self.log_dir, self.str_time)
        self.img_save_dir = os.path.join(self.save_dir, "images")
        self.model_save_dir = os.path.join(self.save_dir, "checkpoints")
        for path in (
            self.log_dir,
            self.save_dir,
            self.img_save_dir,
            self.model_save_dir,
        ):
            if not os.path.exists(path):
                os.makedirs(path)
        # model
        self.d_model = 256
        self.num_head = 8
        self.num_encoder_layers = 2
        self.num_decoder_layers = 2
        self.dim_feedforward = 1024
        self.dropout = 0.1
        # train
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 128
        self.val_ratio = 0.1
        self.epochs = 20
        self.warmup_ratio = 0.12
        self.lr_max = 1e-3
        self.lr_min = 1e-4
        self.beta1 = 0.9
        self.beta2 = 0.98
        self.epsilon = 10e-9
        self.weight_decay = 0.01
        self.early_stop = True
        self.patience = 4
        self.delta = 0
        # log
        logger.remove()
        level_std = "DEBUG" if self.debug else "INFO"
        logger.add(
            sys.stdout,
            colorize=True,
            format="[<green>{time:YYYY-MM-DD HH:mm:ss,SSS}</green>|<level>{level: <8}</level>|<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan>] >>> <level>{message}</level>",
            level=level_std,
        )
        logger.add(
            os.path.join(self.save_dir, f"{self.str_time}.log"),
            format="[{time:YYYY-MM-DD HH:mm:ss,SSS}|{level: <8}|{name}:{function}:{line}] >>> {message}",
            level="INFO",
        )
        logger.info("### Config:")
        for key, value in self.__dict__.items():
            logger.info(f"### {key:20} = {value}")

Dataset

This part defines the vocabulary, the dataset, and the related helpers: loading the data, building the vocabulary, wrapping it in a Dataset, and creating the DataLoader.

The corresponding project file is dataset/dataset.py.

from collections import Counter

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader


def load_data(filepaths, tokenizer=lambda s: s.strip().split()):
    raw_in_iter = iter(open(filepaths[0], encoding="utf8"))
    raw_out_iter = iter(open(filepaths[1], encoding="utf8"))
    return list(zip(map(tokenizer, raw_in_iter), map(tokenizer, raw_out_iter)))


class Vocab(object):
    UNK = "<unk>"  # 0
    PAD = "<pad>"  # 1
    BOS = "<bos>"  # 2
    EOS = "<eos>"  # 3

    def __init__(self, data=None, min_freq=1):
        counter = Counter()
        for lines in data:
            counter.update(lines[0])
            counter.update(lines[1])
        self.word2idx = {Vocab.UNK: 0, Vocab.PAD: 1, Vocab.BOS: 2, Vocab.EOS: 3}
        self.idx2word = {0: Vocab.UNK, 1: Vocab.PAD, 2: Vocab.BOS, 3: Vocab.EOS}
        idx = 4
        for word, freq in counter.items():
            if freq >= min_freq:
                self.word2idx[word] = idx
                self.idx2word[idx] = word
                idx += 1

    def __len__(self):
        return len(self.word2idx)

    def __getitem__(self, word):
        return self.word2idx.get(word, 0)

    def __call__(self, word):
        if not isinstance(word, (list, tuple)):
            return self[word]
        return [self[w] for w in word]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple, np.ndarray, torch.Tensor)):
            return self.idx2word[int(indices)]
        return [self.idx2word[int(i)] for i in indices]


def pad_sequence(sequences, batch_first=False, padding_value=0):
    max_len = max([s.size(0) for s in sequences])
    out_tensors = []
    for tensor in sequences:
        padding_content = [padding_value] * (max_len - tensor.size(0))
        tensor = torch.cat([tensor, torch.tensor(padding_content)], dim=0)
        out_tensors.append(tensor)
    out_tensors = torch.stack(out_tensors, dim=1)
    if batch_first:
        out_tensors = out_tensors.transpose(0, 1)
    return out_tensors.long()


class CoupletsDataset(Dataset):
    def __init__(self, data, vocab):
        self.data = data
        self.vocab = vocab
        self.PAD_IDX = self.vocab[self.vocab.PAD]
        self.BOS_IDX = self.vocab[self.vocab.BOS]
        self.EOS_IDX = self.vocab[self.vocab.EOS]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        raw_in, raw_out = self.data[index]
        in_tensor_ = torch.LongTensor(self.vocab(raw_in))
        out_tensor_ = torch.LongTensor(self.vocab(raw_out))
        return in_tensor_, out_tensor_

    def collate_fn(self, batch):
        in_batch, out_batch = [], []
        for in_, out_ in batch:
            in_batch.append(in_)
            out_ = torch.cat(
                [
                    torch.LongTensor([self.BOS_IDX]),
                    out_,
                    torch.LongTensor([self.EOS_IDX]),
                ],
                dim=0,
            )
            out_batch.append(out_)
        in_batch = pad_sequence(in_batch, True, self.PAD_IDX)
        out_batch = pad_sequence(out_batch, True, self.PAD_IDX)
        return in_batch, out_batch

    def get_loader(self, batch_size, shuffle=False, num_workers=0):
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=self.collate_fn,
            pin_memory=True,
        )
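
As a quick sanity check, the pieces above can be wired together like this (a minimal sketch assuming the repository's data files are in place; the batch size and prints are only illustrative):

from dataset import load_data, Vocab, CoupletsDataset

# Build (upper, lower) couplet pairs and a shared vocabulary over both sides
data = load_data(["data/fixed_couplets_in.txt", "data/fixed_couplets_out.txt"])
vocab = Vocab(data)  # ids 0-3 are reserved for <unk>, <pad>, <bos>, <eos>

dataset = CoupletsDataset(data, vocab)
loader = dataset.get_loader(batch_size=4, shuffle=True)

src, tgt = next(iter(loader))
# src: [4, src_len], padded with <pad>
# tgt: [4, tgt_len + 2], wrapped in <bos> ... <eos> and padded with <pad>
print(src.shape, tgt.shape)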

Model

This part defines the model, consisting of TokenEmbedding, PositionalEncoding, and CoupletsTransformer.

The corresponding project file is model/model.py.

import math
import torch
import torch.nn as nn


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens):
        return self.embedding(tokens) * math.sqrt(self.emb_size)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class CoupletsTransformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
    ):
        super(CoupletsTransformer, self).__init__()
        self.name = "CoupletsTransformer"
        self.token_embedding = TokenEmbedding(vocab_size, d_model)
        self.pos_embedding = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.fc = nn.Linear(d_model, vocab_size)
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, tgt, padding_value=0):
        src_embed = self.token_embedding(src)  # [batch_size, src_len, embed_dim]
        src_embed = self.pos_embedding(src_embed)  # [batch_size, src_len, embed_dim]
        tgt_embed = self.token_embedding(tgt)  # [batch_size, tgt_len, embed_dim]
        tgt_embed = self.pos_embedding(tgt_embed)  # [batch_size, tgt_len, embed_dim]

        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(-1)).to(
            tgt.device
        )
        src_key_padding_mask = src == padding_value  # [batch_size, src_len]
        tgt_key_padding_mask = tgt == padding_value  # [batch_size, tgt_len]

        outs = self.transformer(
            src=src_embed,
            tgt=tgt_embed,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )  # [batch_size, tgt_len, embed_dim]
        logits = self.fc(outs)  # [batch_size, tgt_len, vocab_size]
        return logits

    def encoder(self, src):
        src_embed = self.token_embedding(src)
        src_embed = self.pos_embedding(src_embed)
        memory = self.transformer.encoder(src_embed)
        return memory

    def decoder(self, tgt, memory):
        tgt_embed = self.token_embedding(tgt)
        tgt_embed = self.pos_embedding(tgt_embed)
        outs = self.transformer.decoder(tgt_embed, memory=memory)
        return outs

    def generate(self, text, vocab):
        self.eval()
        device = next(self.parameters()).device
        max_len = len(text)
        src = torch.LongTensor(vocab(list(text))).unsqueeze(0).to(device)
        memory = self.encoder(src)
        l_out = [vocab.BOS]
        for i in range(max_len):
            tgt = torch.LongTensor(vocab(l_out)).unsqueeze(0).to(device)
            outs = self.decoder(tgt, memory)
            prob = self.fc(outs[:, -1, :])
            next_token = vocab.to_tokens(prob.argmax(1).item())
            if next_token == vocab.EOS:
                break
            l_out.append(next_token)
        return "".join(l_out[1:])

Utilities (Utils)

This part defines utility helpers, namely EarlyStopping.

The corresponding project file is utils/EarlyStopping.py.

class EarlyStopping(object):
    def __init__(self, patience=7, delta=0):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float("inf")
        self.delta = delta

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0
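
In train.py this is called once per epoch with the validation loss. A minimal standalone sketch of the behaviour (the loss values are made up; note that the model argument is accepted but not used by this implementation):

early_stopping = EarlyStopping(patience=4, delta=0)
for epoch, val_loss in enumerate([1.00, 0.90, 0.91, 0.92, 0.93, 0.94], start=1):
    early_stopping(val_loss, model=None)  # model is unused here
    if early_stopping.early_stop:
        print(f"Early stopping at epoch {epoch}")  # triggers at epoch 6 in this example
        break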

Training (Train)

This part defines the training procedure: training, validation, and saving the model.

The corresponding project file is train.py.

import os
import gc
import time
import math
import random
import joblib
import warnings

warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import seaborn as sns
from loguru import logger
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

sns.set_theme(
    style="darkgrid", font_scale=1.2, font="SimHei", rc={"axes.unicode_minus": False}
)

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import LambdaLR


from config import Config
from model import CoupletsTransformer
from dataset import load_data, Vocab, CoupletsDataset
from utils import EarlyStopping


def train_model(
    config, model, train_loader, val_loader, optimizer, criterion, scheduler
):
    model = model.to(config.device)
    best_loss = float("inf")
    history = []
    model_path = os.path.join(config.model_save_dir, f"{model.name}_best.pth")
    if config.early_stop:
        early_stopping = EarlyStopping(patience=config.patience, delta=config.delta)
    for epoch in tqdm(range(1, config.epochs + 1), desc="All"):
        train_loss = train_one_epoch(
            config, model, train_loader, optimizer, criterion, scheduler
        )
        val_loss = evaluate(config, model, val_loader, criterion)

        perplexity = math.exp(val_loss)
        history.append((epoch, train_loss, val_loss))
        msg = f"Epoch {epoch}/{config.epochs}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}, Perplexity: {perplexity:.4f}"
        logger.info(msg)
        if val_loss < best_loss:
            logger.info(
                f"Val loss decrease from {best_loss:>10.6f} to {val_loss:>10.6f}"
            )
            torch.save(model.state_dict(), model_path)
            best_loss = val_loss
        if config.early_stop:
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                logger.info(f"Early stopping at epoch {epoch}")
                break
    logger.info(f"Save best model with val loss {best_loss:.6f} to {model_path}")

    model_path = os.path.join(config.model_save_dir, f"{model.name}_last.pth")
    torch.save(model.state_dict(), model_path)
    logger.info(f"Save last model with val loss {val_loss:.6f} to {model_path}")

    history = pd.DataFrame(
        history, columns=["Epoch", "Train Loss", "Val Loss"]
    ).set_index("Epoch")
    history.plot(
        subplots=True, layout=(1, 2), sharey="row", figsize=(14, 6), marker="o", lw=2
    )
    history_path = os.path.join(config.img_save_dir, "history.png")
    plt.savefig(history_path, dpi=300)
    logger.info(f"Save history to {history_path}")


def train_one_epoch(config, model, train_loader, optimizer, criterion, scheduler):
    model.train()
    train_loss = 0
    for src, tgt in tqdm(train_loader, desc="Epoch", leave=False):
        src, tgt = src.to(config.device), tgt.to(config.device)
        output = model(src, tgt[:, :-1], config.PAD_IDX)
        output = output.contiguous().view(-1, output.size(-1))
        tgt = tgt[:, 1:].contiguous().view(-1)
        loss = criterion(output, tgt)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    return train_loss / len(train_loader)


def evaluate(config, model, val_loader, criterion):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for src, tgt in tqdm(val_loader, desc="Val", leave=False):
            src, tgt = src.to(config.device), tgt.to(config.device)
            output = model(src, tgt[:, :-1], config.PAD_IDX)
            output = output.contiguous().view(-1, output.size(-1))
            tgt = tgt[:, 1:].contiguous().view(-1)
            loss = criterion(output, tgt)
            val_loss += loss.item()
    return val_loss / len(val_loader)


def test_model(model, data, vocab):
    model.eval()
    for src_text, tgt_text in data:
        src_text, tgt_text = "".join(src_text), "".join(tgt_text)
        out_text = model.generate(src_text, vocab)
        logger.info(f"\nInput: {src_text}\nTarget: {tgt_text}\nOutput: {out_text}")


def seed_everything(seed):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def main():
    config = Config()

    # Set random seed
    seed_everything(config.seed)
    logger.info(f"Set random seed to {config.seed}")

    # Set cuDNN
    if config.cuDNN:
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

    # Load data
    data = load_data([config.in_path, config.out_path])
    if config.debug:
        data = data[:1000]
    logger.info(f"Load {len(data)} couplets")

    # Build vocab
    vocab = Vocab(data)
    vocab_size = len(vocab)
    logger.info(f"Build vocab with {vocab_size} tokens")
    vocab_path = os.path.join(config.model_save_dir, "vocab.pkl")
    joblib.dump(vocab, vocab_path)
    logger.info(f"Save vocab to {vocab_path}")

    # Build dataset
    data_train, data_val = train_test_split(
        data, test_size=config.val_ratio, random_state=config.seed, shuffle=True
    )
    train_dataset = CoupletsDataset(data_train, vocab)
    val_dataset = CoupletsDataset(data_val, vocab)

    config.PAD_IDX = train_dataset.PAD_IDX

    logger.info(f"Build train dataset with {len(train_dataset)} samples")
    logger.info(f"Build val dataset with {len(val_dataset)} samples")

    # Build dataloader
    train_loader = train_dataset.get_loader(
        config.batch_size, shuffle=True, num_workers=config.num_workers
    )
    val_loader = val_dataset.get_loader(
        config.batch_size, shuffle=False, num_workers=config.num_workers
    )
    logger.info(f"Build train dataloader with {len(train_loader)} batches")
    logger.info(f"Build val dataloader with {len(val_loader)} batches")

    # Build model
    model = CoupletsTransformer(
        vocab_size=vocab_size,
        d_model=config.d_model,
        nhead=config.num_head,
        num_encoder_layers=config.num_encoder_layers,
        num_decoder_layers=config.num_decoder_layers,
        dim_feedforward=config.dim_feedforward,
        dropout=config.dropout,
    )
    logger.info(f"Build model with {model.name}")

    # Build optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=1,
        betas=(config.beta1, config.beta2),
        eps=config.epsilon,
        weight_decay=config.weight_decay,
    )

    # Build criterion
    criterion = nn.CrossEntropyLoss(ignore_index=config.PAD_IDX, reduction="mean")

    # Build scheduler
    lr_max, lr_min = config.lr_max, config.lr_min
    T_max = config.epochs * len(train_loader)
    warm_up_iter = int(T_max * config.warmup_ratio)

    def WarmupExponentialLR(cur_iter):
        gamma = math.exp(math.log(lr_min / lr_max) / (T_max - warm_up_iter))
        if cur_iter < warm_up_iter:
            return (lr_max - lr_min) * (cur_iter / warm_up_iter) + lr_min
        else:
            return lr_max * gamma ** (cur_iter - warm_up_iter)

    scheduler = LambdaLR(optimizer, lr_lambda=WarmupExponentialLR)

    df_lr = pd.DataFrame(
        [WarmupExponentialLR(i) for i in range(T_max)],
        columns=["Learning Rate"],
    )
    plt.figure(figsize=(10, 6))
    sns.lineplot(data=df_lr, linewidth=2)
    plt.title("Learning Rate Schedule")
    plt.xlabel("Iteration")
    plt.ylabel("Learning Rate")
    lr_img_path = os.path.join(config.img_save_dir, "lr_schedule.png")
    plt.savefig(lr_img_path, dpi=300)
    logger.info(f"Save learning rate schedule to {lr_img_path}")

    # Garbage collect
    gc.collect()
    torch.cuda.empty_cache()

    # Train model
    train_model(
        config, model, train_loader, val_loader, optimizer, criterion, scheduler
    )

    # Test model
    test_model(model, data_val[:10], vocab)


if __name__ == "__main__":
    main()
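
Two small details are easy to miss when reading train.py: the optimizer is built with lr=1 because LambdaLR multiplies that base value by whatever WarmupExponentialLR returns, so the lambda effectively yields the absolute learning rate; and the perplexity logged after each epoch is simply the exponential of the mean validation cross-entropy loss (padding positions are excluded via ignore_index):

$$\mathrm{PPL} = \exp\big(\mathcal{L}_{\text{val}}\big)$$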

Web UI

This part defines the Web UI: the input and output components and launching the app.

The corresponding project file is webui.py.

import random
import joblib

import torch
import gradio as gr

from dataset import Vocab
from model import CoupletsTransformer

data_path = "./data/fixed_couplets_in.txt"
vocab_path = "./trained/vocab.pkl"
model_path = "./trained/CoupletsTransformer_best.pth"


vocab = joblib.load(vocab_path)
vocab_size = len(vocab)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CoupletsTransformer(
    vocab_size,
    d_model=256,
    nhead=8,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=1024,
    dropout=0.1,
).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

example = (
    line.replace(" ", "").strip() for line in iter(open(data_path, encoding="utf8"))
)
example = [line for line in example if len(line) > 5]

example = random.sample(example, 300)


def generate_couplet(vocab, model, src_text):
    if not src_text:
        return "上联不能为空"
    out_text = model.generate(src_text, vocab)
    return out_text


input_text = gr.Textbox(
    label="上联",
    placeholder="在这里输入上联",
    max_lines=1,
    lines=1,
    show_copy_button=True,
    autofocus=True,
)

output_text = gr.Textbox(
    label="下联",
    placeholder="在这里生成下联",
    max_lines=1,
    lines=1,
    show_copy_button=True,
)

demo = gr.Interface(
    fn=lambda x: generate_couplet(vocab, model, x),
    inputs=input_text,
    outputs=output_text,
    title="中文对联生成器",
    description="输入上联,生成下联",
    allow_flagging="never",
    submit_btn="生成下联",
    clear_btn="清空",
    examples=example,
    examples_per_page=50,
)

demo.launch()