Writing Acrostic Poems with NLP

This is a simple project, comparable to MNIST or CIFAR classification, and can be treated as an entry-level exercise in NLP.

Data Preprocessing

I chose what is arguably the most complete Chinese poetry dataset out there: https://github.com/chinese-poetry/chinese-poetry/

The dataset really is comprehensive, spanning thousands of years of Chinese history. After downloading it, we only use the Tang poems inside the json folder (the poet.tang.xxxxx.json files), which is still plenty: more than forty thousand poems.

All the network needs for training is the plain poem text, so we process the json files and write the text into a txt file.

Create writedata.py and write everything into poets.txt. To keep the poem format consistent we only keep lines in the standard five-character and seven-character forms, so we use a regular expression to match the lines that fit.

import json,re

#Tang-poem files in the dataset are numbered poet.tang.0.json, poet.tang.1000.json, ..., poet.tang.57000.json
with open('poets.txt','w',encoding='utf-8') as poets:
  for x in range(0,58000,1000):
    filename='poet.tang.'+str(x)+'.json'
    with open(filename,'r',encoding='utf-8') as f:
      localfile=json.load(f)#parse the json file into a list of poem dicts
    for poet in localfile:
      lines=re.findall(r'\w{5}，\w{5}。|\w{7}，\w{7}。',''.join(poet['paragraphs']))#keep only five- and seven-character lines
      if lines:poets.write(''.join(lines)+'\n')#one poem per line

Next we define a custom dataset, just like a dataset for a CV task. The only difference is that PyTorch cannot take Chinese characters directly as input, so we build a dict that maps each character to an integer index (which also makes any later one-hot encoding easy) and define two special tokens marking the start and end of a poem. When iterating over the dataset we are really iterating over sequences of character indices.

class Mydata(Dataset):
  def __init__(self,file="D:\\Documents\\Dataset\\chinese-poetry-master\\json\\poets.txt",seq_len=48):#the poem file and the length of one training sequence
    sos,eos=0,1#indices marking the start and end of a poem
    self.seq_len=seq_len
    with open(file,encoding='utf-8') as f:
      lines=f.read().splitlines()
      self.wordindex={'<SOS>':sos,'<EOS>':eos}#mapping from Chinese character to index
      point,wordnum=list(),2#character indices start at 2 so they do not collide with <SOS>/<EOS>
      for line in lines:
        point.append(sos)
        for word in line:
          if word not in self.wordindex:
            self.wordindex[word]=wordnum
            wordnum+=1
          point.append(self.wordindex[word])
        point.append(eos)
    self.indexword={y:x for x,y in self.wordindex.items()}#mapping from index back to character, used for decoding
    self.data=np.array(point,dtype=np.int64)
  def __len__(self):
    return (len(self.data)-1)//self.seq_len
  def __getitem__(self,i):
    start=i*self.seq_len
    end=start+self.seq_len
    return(torch.as_tensor(self.data[start:end]),torch.as_tensor(self.data[start+1:end+1]))#input and target; the target is the input shifted by one character

It is worth mapping the integers back to characters once to check that the encoding is correct.
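
For example, a quick sanity check could look like this (a minimal sketch; it assumes the Mydata class above and the poets.txt path configured in it):

dataset = Mydata(seq_len=48)
inp, target = dataset[0]
print(''.join(dataset.indexword[int(i)] for i in inp))     # should read like poem text, with <SOS>/<EOS> markers
print(''.join(dataset.indexword[int(i)] for i in target))  # the same text shifted left by one character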

Building the Network

The network can be very simple, since we are building the most basic sequence-generation model. Here I only use an Embedding layer plus an LSTM, followed by two linear layers that map the hidden vector back to a word (vec2word).

class PoetNet(nn.Module):
  def __init__(self, vocab_size, embed_size, hidden_size, lstm_layers):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, embed_size)  # map each character index to a dense vector
    self.lstm = nn.LSTM(embed_size, hidden_size, lstm_layers, batch_first=True)
    self.h2h = nn.Linear(hidden_size, hidden_size)
    self.h2o = nn.Linear(hidden_size, vocab_size)  # project the hidden state onto vocabulary logits

  def forward(self, word_ids, lstm_hidden=None):
    embedded = self.embedding(word_ids)
    lstm_out, lstm_hidden = self.lstm(embedded, lstm_hidden)
    out = self.h2h(lstm_out)
    out = self.h2o(out)
    return out, lstm_hidden
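
As a quick shape check (a minimal sketch; the vocab_size of 8000 is a made-up placeholder, while the other hyperparameters match the settings used later):

import torch

model = PoetNet(vocab_size=8000, embed_size=100, hidden_size=1024, lstm_layers=2)
dummy = torch.randint(0, 8000, (32, 48))  # (batch_size, seq_len) of character indices
logits, hidden = model(dummy)
print(logits.shape)                        # torch.Size([32, 48, 8000]): a distribution over the vocabulary at every position
print(hidden[0].shape, hidden[1].shape)    # h and c states: torch.Size([2, 32, 1024]) each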

Training the Network

The training procedure is also much like MNIST, except that in each forward pass we feed the network one character and it predicts the next one; the prediction is compared against the corresponding character of the original poem with cross-entropy, then the gradients are computed and the weights updated. Before long the network produces lines in the xxxxx，xxxxx。 format.

Since the <SOS> token sits at the start of each poem and the <EOS> token comes right after the closing 。, feeding the network a single character makes it keep generating until it has completed a whole xxxxx，xxxxx。 line.

One epoch of training looks like this:

def training_step():
  for i, (input_, target) in enumerate(train_loader):
    model.train()
    input_, target = input_.to(device), target.to(device)
    output, _ = model(input_)
    loss = F.cross_entropy(output.reshape(-1, vocab_size), target.flatten())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    acc = accuracy(output, target)
    print(
      "Training: Epoch=%d, Batch=%d/%d, Loss=%.4f, Accuracy=%.4f"
      % (epoch, i, len(train_loader), loss.item(), acc)
    )
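
Once training has converged, the acrostic itself comes from the generate() function in the full code below: each head character is fed to the network, which then keeps taking the most likely next character (greedy decoding) until it emits a 。, at which point the next head character starts a new line. A minimal usage sketch, assuming a trained model and the dataset object from above:

model.eval()
with torch.no_grad():
  print(generate("深。度。学。习"))  # head characters separated by '。'
  # roughly: 深xxxx，xxxxx。度xxxx，xxxxx。...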

Finally, the results are as shown in the figure.

Full Code

import glob
import datetime

import numpy as np

import torch
import torch.optim as op
from torch import nn
import torch.nn.functional as F

from torch.utils.data import DataLoader,Dataset,random_split
from torch.utils.tensorboard import SummaryWriter

#Initial hyperparameter settings
debug=False
embed_size=100
hidden_size=1024
lr=0.001
lstm_layers=2
batch_size=32
epochs=128
seq_len = 64
checkpoint_dir="./"
#Dataset class
class Mydata(Dataset):
  def __init__(self,file="./CPM/json/poets.txt",seq_len=64):
    sos,eos=0,1
    self.seq_len=seq_len
    with open(file,encoding='utf-8') as f:
      lines=f.read().splitlines()
      self.wordindex={'<SOS>':sos,'<EOS>':eos}
      point,wordnum=list(),2#character indices start at 2 so they do not collide with <SOS>/<EOS>
      for line in lines:
        point.append(sos)
        for word in line:
          if word not in self.wordindex:
            self.wordindex[word]=wordnum
            wordnum+=1
          point.append(self.wordindex[word])
        point.append(eos)
    self.indexword={x:y for y,x in self.wordindex.items()}
    self.data=np.array(point,dtype=np.int64)
  def __len__(self):
    return (len(self.data)-1)//self.seq_len
  def __getitem__(self,i):
    start=i*self.seq_len
    end=start+self.seq_len
    return(torch.as_tensor(self.data[start:end]),torch.as_tensor(self.data[start+1:end+1]))
class PoetNet(nn.Module):
  def __init__(self, vocab_size, embed_size, hidden_size, lstm_layers):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, embed_size)
    self.lstm = nn.LSTM(embed_size, hidden_size, lstm_layers, batch_first=True)
    self.h2h = nn.Linear(hidden_size, hidden_size)
    self.h2o = nn.Linear(hidden_size, vocab_size)

  def forward(self, word_ids, lstm_hidden=None):
    embedded = self.embedding(word_ids)
    lstm_out, lstm_hidden = self.lstm(embedded, lstm_hidden)
    out = self.h2h(lstm_out)
    out = self.h2o(out)
    return out, lstm_hidden

def accuracy(output, target):
  """Compute the accuracy between model output and ground truth.
  Args:
      output: (batch_size, seq_len, vocab_size)
      target: (batch_size, seq_len)
  Returns:
      float: accuracy value between 0 and 1
  """
  output = output.reshape(-1, vocab_size)
  target = target.flatten()

  a = output.topk(1).indices.flatten()
  b = target
  return a.eq(b).sum().item() / len(a)


def generate(start_phrases):
  start_phrases = start_phrases.split("。")

  hidden = None

  def next_word(input_word):
    nonlocal hidden
    input_word_index = dataset.wordindex[input_word]
    input_ = torch.Tensor([[input_word_index]]).long().to(device)
    output, hidden = model(input_, hidden)
    top_word_index = output[0].topk(1).indices.item()

    return dataset.indexword[top_word_index]

  result = []  # a list of output words
  cur_word = "。"

  for i in range(seq_len):
    if cur_word == "。":  # end of a sentence
      result.append(cur_word)
      next_word(cur_word)

      if len(start_phrases) == 0:
        break

      for w in start_phrases.pop(0):
        result.append(w)
        cur_word = next_word(w)

    else:
      result.append(cur_word)
      cur_word = next_word(cur_word)

  result = "".join(result)
  result = result.strip("。")
  return result


def training_step():
  for i, (input_, target) in enumerate(train_loader):
    model.train()

    input_, target = input_.to(device), target.to(device)

    output, _ = model(input_)
    loss = F.cross_entropy(output.reshape(-1, vocab_size), target.flatten())

    optimizer.zero_grad()  # Make sure gradient does not accumulate
    loss.backward()  # Compute gradient
    optimizer.step()  # Update NN weights

    acc = accuracy(output, target)

    print(
      "Training: Epoch=%d, Batch=%d/%d, Loss=%.4f, Accuracy=%.4f"
      % (epoch, i, len(train_loader), loss.item(), acc)
    )

    if not debug:
      step = epoch * len(train_loader) + i
      writer.add_scalar("loss/training", loss.item(), step)
      writer.add_scalar("accuracy/training", acc, step)

      if i % 50 == 0:
        generated_lyrics = generate("皇。函。谷。居")
        writer.add_text("generated_lyrics", generated_lyrics, i)
        writer.flush()


def evaluation_step():
  model.eval()

  epoch_loss = 0
  epoch_acc = 0

  with torch.no_grad():
    for data in test_loader:
      input_, target = data[0].to(device), data[1].to(device)

      output, _ = model(input_)
      loss = F.cross_entropy(output.reshape(-1, vocab_size), target.flatten())

      epoch_acc += accuracy(output, target)

      epoch_loss += loss.item()

  epoch_loss /= len(test_loader)
  epoch_acc /= len(test_loader)
  print(
    "Validation: Epoch=%d, Loss=%.4f, Accuracy=%.4f"
    % (epoch, epoch_loss, epoch_acc)
  )

  if not debug:
    writer.add_scalar("loss/validation", epoch_loss, epoch)
    writer.add_scalar("accuracy/validation", epoch_acc, epoch)
    writer.flush()


def save_checkpoint():
  torch.save(
    {
      "epoch": epoch,
      "model_state_dict": model.state_dict(),
      "optimizer_state_dict": optimizer.state_dict(),
    },
    "checkpoint-%s.pth" % datetime.datetime.now().strftime("%y%m%d-%H%M%S"),
  )


def load_checkpoint(file):
  global epoch

  ckpt = torch.load(file)

  print("Loading checkpoint from %s." % file)

  model.load_state_dict(ckpt["model_state_dict"])

  optimizer.load_state_dict(ckpt["optimizer_state_dict"])

  epoch = ckpt["epoch"]


if __name__ == "__main__":
  # Create cuda device to train model on GPU
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # Define dataset
  dataset = Mydata(seq_len=seq_len)

  # Split dataset into training and validation
  data_length = len(dataset)
  lengths = [int(data_length - 1000), 1000]
  train_data, test_data = random_split(dataset, lengths)

  # Create data loader
  train_loader = DataLoader(
    train_data, batch_size=batch_size, shuffle=True, num_workers=0
  )
  test_loader = DataLoader(
    test_data, batch_size=batch_size, shuffle=True, num_workers=0
  )
  if debug:
    train_loader = [next(iter(train_loader))]
    test_loader = [next(iter(test_loader))]

  # Create NN model
  vocab_size = len(dataset.wordindex)
  model = PoetNet(
    vocab_size=vocab_size,
    embed_size=embed_size,
    hidden_size=hidden_size,
    lstm_layers=lstm_layers,
  )
  model = model.to(device)

  # Create optimizer
  optimizer = op.Adam(model.parameters(), lr=lr)
  # optimizer = op.SGD(model.parameters(), lr=0.01, momentum=0.9)

  # Load checkpoint
  checkpoint_files = glob.glob("checkpoint-*.pth")
  if (
      not debug
      and len(checkpoint_files) > 0
      and input("Enter y to load %s: " % checkpoint_files[-1]) == "y"
  ):
    load_checkpoint(checkpoint_files[-1])
  else:
    epoch = 0

  if (
      input("Enter y to enter inference mode, anything else to enter training mode: ")
      == "y"
  ):
    # Inference loop
    while True:
      start_words = input("Enter start-words divided by '。' (e.g. '深。度。学。习'): ")
      if not start_words:
        break

      print(generate(start_words))

  else:
    if not debug:
      writer = SummaryWriter()

    # Optimization loop
    while epoch < epochs:
      training_step()
      evaluation_step()

      if not debug:
        save_checkpoint()

      epoch += 1

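Training metrics and sample generated poems are logged through SummaryWriter, so progress can be followed in TensorBoard; with SummaryWriter's default log directory this is simply:

tensorboard --logdir runs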