transformer和文本纠错初探

一、transformer简介

         Transformer模型是谷歌用于机器翻译而提出的seq2seq模型,在当时达到了sota效果。它改进了RNN训练慢的特点,同时利用了attention机制实现快速并行以及关键信息的聚焦。Transformer由论文《Attention is All You Need》提出,论文地址——Attention is All You Need:https://arxiv.org/abs/1706.03762。Transformer架构如下:

      可以看到左边是N个encoder模块,右边是N个decoder模块。输入是token_embedding和position_embedding。可以看到核心的模块就是多头attention、残差模块和前馈神经网络。不同的是encoder中attention只是使用了self-attention机制,而decoder中使用了self-attention和cross-attention机制(这里的输入就是每一层encoder的输出,cross-attention模块是有mask的,目的是防止token与它之后的token去做attention,也就是left-to-right模型)

decoder端的Mask的作用——在训练和推理阶段保持一致性。在推理阶段,token是从左往右的顺序进行推理的。在推理timestep时刻的token时候,decoder只能看到timestep<T的T-1个token,不能和timestep大于它自身的token做attention(因为根本还不知道后面的token是什么),为了保证训练时和推理时的一致性,所以,训练时要同样防止token与它之后的token去做attention——这个是论文的解释。以上,在实现的时候decoder的mask特别重要。

FFN前馈神经网络

网络中的FFN前馈神经网络具体的作用就是对embedding进行空间变换,引入激活函数relu,变换了attention output的空间,因此增加了模型的非线性,增加了模型的表现能力。

position_embedding

position_embedding这个部分就是用来捕捉每个token的位置信息,如果没有position_embedding模型完全就不能工作。而且这也被训练出来的。position_embedding和每一个token都是唯一绑定的。论文中使用的Positional Encoding(PE)是正余弦函数,学习到了相对位置关系,bert中是使用的nn.embedding来学习位置编码的。

 

残差连接

如图所示红色线条部分就是残差连接——残差网络的思想。它的作用是什么呢?

它能解决神经网络退化的问题(随着网络深度增加,网络模型的表现先是逐渐增加然后至饱和然后迅速下降),另一方面错误信号可以不经过任何中间权重矩阵变换直接传播到低层,一定程度上可以缓解梯度弥散问题。综上,残差连接可以使得信息前后向传播更加顺畅。

 

二、transformer的实现

由于transformer很火,网上也有很多版本的实现,这里简单的将这些代码做一下整理。看看最简单的实现,直接调用torch封装的API。这里没有把多头Attention、前馈神经网络、层归一化等细节模块的实现呈现出来,都是使用了高阶函数封装起来了。

1、基于torch封装的transformer组件实现transformer

import torch.nn as nn
import torch.nn.modules.transformer as T
from .PositionalEncoding import PositionalEncoding
from .Embeddings import Embeddings


class Transformer(nn.Module):
    def __init__(self,max_len: int = 64,num_of_vocab: int=21128, d_model: int = 768, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: str = "relu"):
        super(Transformer, self).__init__()

        #encoder由encoder_layer和encoder_norm组成
        encoder_norm = T.LayerNorm(d_model)
        encoder_layer = T.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
        self.encoder = T.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        # decoder由decoder_layer和decoder_norm
        decoder_norm = T.LayerNorm(d_model)
        decoder_layer = T.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
        self.decoder = T.TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self.embeddings = Embeddings(d_model, num_of_vocab)
        self.positional_encoding = PositionalEncoding(d_model=d_model,max_len=max_len)
        self.projection = nn.Linear(d_model, num_of_vocab, bias=False)

    def forward(self,src, tar, src_key_padding_mask, tgt_key_padding_mask,src_mask,tar_mask,memory_key_padding_mask):
        """
        src: 源文本Ids --(N,S,E)
        tar:目标文本Ids --(N,T,E)
        src_key_padding_mask:源文本关键掩码 --(S,S)
        tgt_key_padding_mask:目标文本掩码 --(T,T)
        src_mask:the additive mask for the src sequence --(N,S)
        tar_mask:the additive mask for the tgt sequence --(N,T)
        N:batch_size
        S T:序列长度


        decoder的输出[T,N,E]
        """

        src_embedding = self.embeddings(src)
        src_embedding = self.positional_encoding(src_embedding)
        
        #维度要变换一下,(N,S,E)----->(S,N,E)
        src_embedding = src_embedding.permute(1,0,2)
        memory = self.encoder(src_embedding,mask=src_mask, 
        src_key_padding_mask=src_key_padding_mask)

        #decoder输入的位置编码和token_embedding
        tar_embedding = self.embeddings(tar)
        tar_embedding = self.positional_encoding(tar_embedding)
        tar_embedding = tar_embedding.permute(1, 0, 2)

       

        #decoder的输入是encoder的输出mermory和tar_embedding以及相关的mask
        decoder_out = self.decoder(tar_embedding,memory,tgt_mask=tar_mask,tgt_key_padding_mask=tgt_key_padding_mask,memory_key_padding_mask=memory_key_padding_mask)

        decoder_out = decoder_out.permute(1,0,2)
        out = self.projection(decoder_out)
        return out

上述代码中关于transformer模型的构建都是直接调用API函数来实现,非常简单和方便。值得注意的是encoder和decoder输入和输出具体形式;encoder需要输入mask编码、token_emebdding和positional_encoding;decoder的输入则是tar_embedding(包含位置编码和token_embedding)、mask和encoder的输出。具体的可以看下图:

在实现的时候需要注意embedding和mask的维度,不能弄错了,弄错了就会报错。

2、position_embedding的实现

关于位置编码,bert的实现中一般采用的是nn.embedding()这种可学习来实现的;而transformer的实现中一般采用的是正余弦函数来实现的。其公式如下:

这两种方式其实效果是差不多的,有论文做了分析。不过使用正余弦函数的这种实现方式可以编码超出最大序列长度的token。

代码如下:

import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

注意的是经过上面PositionalEncoding模块的处理后,这里不仅仅是位置编码,而是位置编码+token_embedding。

其中有几句代码比较难以理解,这里其实是上面的公式经过了变形推导了的,如下图:

这样下面的4句代码理解起来就容易多了。

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

3、token_embedding 

一般都是直接使用nn.embedding(vaca_len,d_model)来实现词嵌入的初始化,代码如下:

import torch.nn as nn
import math

class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

有一个地方值得注意就是最后的词嵌入为何要乘以embedding size的开方?

参考解释:因为embedding matrix的初始化方式是xavier init,这种方式的方差是1/embedding size,因此乘以embedding size的开方使得embedding matrix的方差是1,在这个scale下可能更有利于embedding matrix的收敛。

4、mask

上面transformer模型API给出的参数中,encoder使用到了src_masksrc_key_padding_mask,decoder使用了tgt_masktgt_key_padding_mask。这里*mask和*padding_mask分别对应两种不同的mask。

先看*_padding_mask

这种mask就是为了把一个序列中起padding作用的无效的位置mask起来,不用参加attention机制的计算。

看一个例子:

序列

我喜欢自然语言处理

最大长度为15的的话,就要做如下padding处理

我喜欢自然语言处理[pad][pad][pad][pad][pad][pad]

映射成mask的话:

111111111000000

上面的这种mask就和bert中的mask输入有点类似了,但是在torch封装的transformer中一般都还要变成bool类型;

    mask = torch.tensor(mask,dtype=torch.long)
    mask = mask < 1

结果如下:

tensor([False, False, False, False, False, False, False, False, False,  True,True,  True,  True,  True,  True])

tgt_mask这一类型的mask就是seq_mask。decoder端在做self-attention的时候,由于transformer在训练的时候是可并行的,它只能看到上文的信息,不能看到当前词之后的信息。在预测的时候也是需要seq_mask的好的一个解释——在测试或者预测时,Transformer里decoder为什么还需要seq mask?——解码器mask设计思路的原因:1、预测阶段要保持重复预测词一致——>必须保持每步attention的值不变——>掩码掉未来词——>mask下三角矩阵;2、恰好也可以使模型在训练阶段的传播过程与预测阶段一致。

实现思路,实现一个上三角矩阵,其他位置为0;然后把0用负无穷来替换,1换成0;这样的矩阵在做self_attention的时候就可以屏蔽掉未来的信息。它的形状是一个二维矩阵(L,L)——L代表序列长度。基于torch很容易实现:

def gen_nopeek_mask(length):
    """
     Returns the nopeek mask
             Parameters:
                     length (int): Number of tokens in each sentence in the target batch
             Returns:
                     mask (arr): tgt_mask, looks like [[0., -inf, -inf],
                                                      [0., 0., -inf],
                                                      [0., 0., 0.]]
     """
    mask = torch.triu(torch.ones(length, length))
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

例如序列长度是15,decoder的mask矩阵就是:

三、模型训练和预测

模型训练其中主要的就是要encoder和decoder的每一个输入一定要正确,主要是token_embedding、paddin_mask和attn_mask以及position_embedding;另外还有就是训练过程中loss计算的理解——这里是计算一个batch内所有序列的每个字和label对应的交叉熵损失,比简单的分类任务要稍微难理解一点点。

训练阶段:

把src和tag文本添加起始和结尾标志符('SOS','BOS'都是可以的),做embedding处理后,同时输入到Transformer模型中去。最终得到output,这个是一个概率分布。接上nn.Linear(d_model, vocab_size)全连接层后就可以进行loss计算了。

测试推理阶段

首先把需要预测的序列输入encoder中,然后第一步:decoder输入起始符号,预测得到一个字;第二步:decoder输入[cls]+第一步的预测结果,又预测得到一个字,……,重复上述步骤直到序列最大长度或者预测得到[SEP]。

loss计算的理解

transformer最终输出的结果output维度是和transformer decoder的tag输入维度是一样的[S,B,D]——S序列长度,B是batch_size,D词嵌入特征数目。

在经过一个nn.Linear(d_model, vocab_size)这样的全连接后,维度信息是[S,B,vocab_size]

而tar_label[B,S]

要是它们能够计算交叉熵损失,则需要把output维度变换为[(B*S),vocab_size],tar_label维度变换为[(B*S),1]——这样就表示一个batch内所有的字的概率分布和label进行损失计算。

简单实现:

loss = criterion(output.contiguous().view(-1, output.size(-1)), tar_outputs.contiguous().view(-1))

注意——x.contiguous().view()和x.reshape()具有相似的作用

四、解码策略

在推理预测阶段,怎么来确定最终的预测结果呢?选用不同策略会有什么样的差异呢?下面详细的进行介绍。

1、贪心搜索(greedy search)

transformer预测阶段得到的概率分布,连接全连接层后,可以得到一个序列的概率分布[(B*S),vocab_size]——含义就是每个字在词表上的概率分布,共有B*S个字。怎么样通过这个概率分布得到最合理的序列。一种很直观的做法就是从每个字的概率分布中取它的最大概率的那个可能性,直到整个序列完成或者发现终止符[SEP]。在transformer中由于decoder的输入在推理预测的时候有点不一样,实现起来的时候也要注意。由上述推力测试阶段简示图,可以发现最开始输入的是起始符[CLS]。实现代码如下:

def greedy_search_decode(model, src,src_key_padding_mask, max_len:int = 64, start_symbol:int = 1):
    """
    :param model: Transformer model
    :param src: the encoder input
    :param max_len: 序列最大长度
    :return:ys 这个就是预测的具体序列
    解码的时候这几个mask是不能够少的
    """
    src_mask = gen_nopeek_mask(src.shape[1]).to(device)
    memory_key_padding_mask = src_key_padding_mask
    
    #最开始的字符[CLS]在词表的位置是1
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len-1):
        tar_mask = gen_nopeek_mask(ys.shape[1]).to(device)
        out = model.forward(src, ys, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=None,src_mask=src_mask,tar_mask=tar_mask,memory_key_padding_mask=memory_key_padding_mask)
        #预测结果out,选取最后一个概率分布
        out = out[:,-1,:]
        #得到最大的那个概率的index,就是该次预测的字在词表的index
        _, next_word = torch.max(out, dim=1)
        next_word = next_word.data[0]
        if next_word != 2:
            #如果没有预测出终止符[SEP]
            #把这次预测的结果和以前的结果cat起来,再次循环迭代预测
            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        else:
            break
    return ys

上述实现的贪心搜索算法,特点是需要循环迭代,不能按批次来处理,速度比较慢。在效果上来看,由于贪心搜索每次预测的时候都是考虑了前一时刻的概率,可能会隐藏未来时刻的最大概率的字符。如下图:

我们使用The 预测得到的最大概率字符是nice,然后用the nice 预测得到的最大字符是 women;而假如第一步the 预测得到的字符不选nice而是选择概率第二大的dog,第那么第二步使用the dog 预测得到的结果就是has,has具有0.9的概率;贪心搜索就会把第二种情况给排除掉;从合理性上来说,贪心搜索得到的序列可能不是最具有合理性的。也就是greedy search找到的序列并不是最大概率序列,为了找到更加合理的序列,提出了集束搜索,找到更大概率序列。

2、集束搜索(beam search)

思想——每一步解码时,仅保留前k个可能的结果。例如在第一步解码时,我们选择前  k个可能的结果  ,分别代入第二步解码中,各取前 k 个候选词,即得到 k平方个候选组合,最后保留概率乘积最大的前  k个候选结果。

当beam size为2时,以上图为例。第一步解码,我们选择概率最大的两个单词[A, C],然后分别带入第二步解码,分别得到[AA, AB, AC, AD, AE, CA, CB, CC, CD, CE] 10种情况,这里仅保留最优的两种情况[AB, CE],然后再继续带入第三步解码。最后得到整体概率最大的序列。

就算有上面的算法,文本预测和生成的时候还是会出现重复的词语或者空洞的词语,有学者提出了对beam search进行改进。

比较详细的介绍参考文章——解读Beam Search

五、文本纠错实战

文本纠错其实也是一个比较难的任务,最近老板比较感兴趣,就进行了学习和实验。

一般而言,使用bert模型的MLM任务就可以把错误的字给纠正过来,有很多人文本纠错就是这么做的。但是这里是不能对缺少字符的和重复字符文本进行纠正。那么我们这里把文本纠错看做一个翻译任务,把错误的文本翻译为正确的文本,这样就可以对缺少字符的和重复字符文本进行纠正。采用深度学习中大名鼎鼎的transformer模型。前面已经分析过transformer模型训练和预测的一些步骤和技术细节。这里对文本纠错做一个全面的展示。

训练集:这里我们把LQCMC整理的相似句子对以及我们自己爬去的专利title名做为tag文本,对句子随机的删除一个字符、替代一个字符以及增加一个字符作为src文本。

为了充分构造足够多的数据集,在实现datareader的时候,采用动态产生噪声数据,那么原始的src和tar都是正确的文本,在datareader的getitem函数中把src文本添加一些噪声就好了。这里的好处是可以构建很多数据,当然可能不是每次构造的数据都是之前训练集中没有的,优点就是在有限的内存下可以采用时间换空间的方法进行模型的训练。

dataReader的代码:

from torch.utils.data import Dataset
import os
import torch
from tqdm import tqdm
import logging
logger = logging.getLogger(__name__)
import json


from dynamic_produce_noise import produce_noise

class DataReader_dynamic(Dataset):
    def __init__(self,args,filename,repeat=1):
        self.max_sentence_length = 64
        self.repeat = repeat
        with open('chinese_vacabulary/token_to_id.json','r') as f:
            self.vaca = json.load(f)

        self.data_list = self.read_file(args.data_dir,filename)
        self.len = len(self.data_list)

    def read_file(self,data_dir,filename):
        process_data_list = []
        file_cach_path = os.path.join(data_dir, "cached_{}".format(filename.split('.')[0]))
        if os.path.exists(file_cach_path):
            process_data_list = torch.load(file_cach_path)
            return process_data_list
        else:
            file_path = os.path.join(data_dir,filename)
            with open(file_path,'r') as f:
                lines = f.readlines()
            for line in tqdm(lines[1:]):
                line = line.strip('\n').split('\t')
                src, tar = line[0],line[1]


                src_input_ids,src_key_padding_mask, src_len = self.convert_into_indextokens_and_segment_id_nopadding(src)
                tar_input_ids,tgt_key_padding_mask,tar_len = self.convert_into_indextokens_and_segment_id_padding(tar)


                tar_input_ids = torch.tensor(tar_input_ids,dtype=torch.long)
                tgt_key_padding_mask = torch.tensor(tgt_key_padding_mask, dtype=torch.long)

                #转化为bool型

                tgt_key_padding_mask = tgt_key_padding_mask < 1


                tar_len = torch.tensor(tar_len,dtype=torch.long)

                process_data_list.append((src_input_ids,tar_input_ids,src_key_padding_mask,tgt_key_padding_mask,src_len, tar_len))

            logger.info('Saving tokenizering into cached file %s', file_cach_path)
            torch.save(process_data_list, file_cach_path)
            return process_data_list

    def tokenize(self,text,vaca):
        text = text.split('\t')
        ids = [ vaca[word] if word in vaca else vaca['[OOV]']  for word in text]
        return ids

    def convert_into_indextokens_and_segment_id_padding(self, text):
        text = text[:self.max_sentence_length - 2]
        text = '[CLS]' + '\t' + '\t'.join(text) + '\t' + '[SEP]'
        input_ids = self.tokenize(text, self.vaca)
        attention_mask = [1] * len(input_ids)
        seq_len = len(text)

        pad_indextokens = [0] * (self.max_sentence_length - len(input_ids))
        input_ids.extend(pad_indextokens)
        attention_mask_pad = [0] * (self.max_sentence_length - len(attention_mask))
        attention_mask.extend(attention_mask_pad)
        return input_ids,attention_mask,seq_len

    def convert_into_indextokens_and_segment_id_nopadding(self,text):
        text = text[:self.max_sentence_length - 2]
        text = '[CLS]' + '\t' + '\t'.join(text) + '\t' + '[SEP]'
        input_ids = self.tokenize(text, self.vaca)
        attention_mask = [1] * len(input_ids)
        seq_len = len(text)

        return input_ids, attention_mask, seq_len

    def __len__(self):
        if self.repeat == None:
            data_len = 10000000
        else:
            data_len = len(self.data_list)
        return data_len

    def __getitem__(self, item):
        src_input_ids = self.data_list[item][0]
        tar_input_ids = self.data_list[item][1]


        tgt_key_padding_mask = self.data_list[item][3]
        tar_len = self.data_list[item][5]

        src_input_ids,src_key_padding_mask, src_len = produce_noise(src_input_ids,self.max_sentence_length)

        src_key_padding_mask = src_key_padding_mask < 1


        return src_input_ids, tar_input_ids, src_key_padding_mask, tgt_key_padding_mask,src_len, tar_len

代码比较简单,就没有注释了,主要是关注produce_noise()函数的实现,其实也很简单,主要是思路上构建一个用于文本纠错的数据集。

import random
import torch
def produce_noise(token_ids,max_len):
    mode = random.randint(0,2)
    if mode == 0:
        token_ids = random_drop(token_ids)
    elif mode == 1:
        token_ids = random_replace(token_ids)
    else :
        token_ids = random_add(token_ids)
    if len(token_ids) > max_len:
        temp = len(token_ids) - max_len
        token_ids = token_ids[:len(token_ids)-temp-1]+ [token_ids[-1]]
        assert len(token_ids) <= max_len

    src_len = len(token_ids)
    mask = [1] * len(token_ids)

    pad_token_ids = [0] * (max_len - len(token_ids))
    token_ids.extend(pad_token_ids)


    pad_mask = [0]*(max_len-len(mask))
    mask.extend(pad_mask)

    token_ids = torch.tensor(token_ids,dtype=torch.long)
    mask = torch.tensor(mask, dtype=torch.long)
    src_len = torch.tensor(src_len,dtype=torch.long)
    
    return token_ids,mask,src_len

def random_drop(token_ids):

    index = random.randint(1,len(token_ids)-2)
    new_token_ids = token_ids[:index]+token_ids[index+1:]


    return new_token_ids
def random_replace(token_ids):
    """
    把其中的字符一个替换成其他的字符
    :param token_ids:
    :return:
    """
    vaca = list(range(5993))
    diff = list(set(vaca).difference(set(token_ids)))

    index = random.randint(1, len(token_ids) - 2)
    diff_index = random.randint(0,len(diff)-1)

    token_ids[index] = diff[diff_index]
    return token_ids


def random_add(token_ids):
    index = random.randint(1, len(token_ids) - 2)
    new_token_ids = token_ids[:index+1] + [token_ids[index]]+ token_ids[index + 1:]
    return new_token_ids

模型代码就是前文实现的transformer模型代码,这里把训练代码放出来。

transformer文本纠错模型训练代码

import torch
import torch.nn as nn
from data_reader.DataReader_dynamic import DataReader_dynamic
from data_reader.DataReader import DataReader
import argparse
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import json
from einops import rearrange
import time
from progressbar import ProgressBar
from transformer_model.transformer import Transformer
# from transformers import BertTokenizer
import pandas as pd

"""
添加了collate_fn效果还差很多!
"""

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def greedy_search_decode(model, src,src_key_padding_mask, max_len:int = 64, start_symbol:int = 1):
    """
    :param model: Transformer model
    :param src: the encoder input
    :param max_len: 序列最大长度
    :return:ys 这个就是预测的具体序列
    解码的时候这几个mask是不能够少的
    """
    src_mask = gen_nopeek_mask(src.shape[1]).to(device)
    memory_key_padding_mask = src_key_padding_mask

    #最开始的字符[CLS]在词表的位置是1
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len-1):
        tar_mask = gen_nopeek_mask(ys.shape[1]).to(device)
        out = model.forward(src, ys, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=None,src_mask=src_mask,tar_mask=tar_mask,memory_key_padding_mask=memory_key_padding_mask)
        #预测结果out,选取最后一个概率分布
        out = out[:,-1,:]
        #得到最大的那个概率的index,就是该次预测的字在词表的index
        _, next_word = torch.max(out, dim=1)
        next_word = next_word.data[0]
        if next_word != 2:
            #如果没有预测出终止符[SEP]
            #把这次预测的结果和以前的结果cat起来,再次循环迭代预测
            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        else:
            break
    return ys


def gen_nopeek_mask(length):
    """
     Returns the nopeek mask
             Parameters:
                     length (int): Number of tokens in each sentence in the target batch
             Returns:
                     mask (arr): tgt_mask, looks like [[0., -inf, -inf],
                                                      [0., 0., -inf],
                                                      [0., 0., 0.]]
     """
    mask = rearrange(torch.triu(torch.ones(length, length)) == 1, 'h w -> w h')
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

    return mask


def collate_fn(batch):
    src_input_ids, tar_input_ids, src_key_padding_mask, tgt_key_padding_mask, src_len, tar_len = map(torch.stack, zip(*batch))
    src_max_len = max(src_len).item()
    tar_max_len = max(tar_len).item()

    src_input_ids = src_input_ids[:,:src_max_len]
    tar_input_ids = tar_input_ids[:,:tar_max_len]

    src_key_padding_mask = src_key_padding_mask[:,:src_max_len]

    tgt_key_padding_mask = tgt_key_padding_mask[:,:tar_max_len]

    return src_input_ids, tar_input_ids, src_key_padding_mask, tgt_key_padding_mask


def train_epochs(args,device):
    model = Transformer()

    model.to(device)

    train_data = DataReader_dynamic(args, filename='train.csv')

    # train_loader = DataLoader(train_data,shuffle=True,batch_size=args.batch_size,collate_fn=collate_fn)
    train_loader = DataLoader(train_data, shuffle=True, batch_size=args.batch_size)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    for epoch in tqdm(range(args.epochs)):

        pbar = ProgressBar(n_total=len(train_loader), desc='Training Iteraction')

        for step, batch in enumerate(train_loader):
            batch = tuple(t.to(device) for t in batch)
            src_input_ids, tar_input_ids, src_key_padding_mask, tgt_key_padding_mask = batch[0], batch[1], batch[2], batch[3]

            src_mask = gen_nopeek_mask(src_input_ids.shape[1]).to(device)

            # 把tar_inputs_ids输入从1到倒数第二个;输出从第一个到最后一个——因为添加了[cls]和[sep]
            tar_inputs, tar_outputs = tar_input_ids[:, :-1], tar_input_ids[:, 1:]
            tar_mask = gen_nopeek_mask(tar_inputs.shape[1]).to(device)

            memory_key_padding_mask = src_key_padding_mask.clone()

            tgt_key_padding_mask = tgt_key_padding_mask[:, :-1]
            output = model(src_input_ids, tar_inputs, src_key_padding_mask, tgt_key_padding_mask, src_mask, tar_mask,
                           memory_key_padding_mask)

            # print(output.contiguous().view(-1,output.size(-1)).shape)
            loss = criterion(output.contiguous().view(-1, output.size(-1)), tar_outputs.contiguous().view(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            model.zero_grad()
            pbar(step, {"loss": loss.item()})

        torch.save(model, 'train_model/transformers_text_denoise_v2.bin')



def evaluation(device,vaca):
    model = torch.load('train_model/transformers_text_denoise_v2.bin')

    dev_data = DataReader(args, filename='dev_noise.csv')

    dev_loader = DataLoader(dev_data, shuffle=False, batch_size=1)

    model.eval()

    predicts = []

    with torch.no_grad():
        for step, batch in tqdm(enumerate(dev_loader)):
            batch = tuple(t.to(device) for t in batch)
            src_input_ids, src_key_padding_mask = batch[0], batch[2]

            out = greedy_search_decode(model, src_input_ids, src_key_padding_mask)

            if out.shape[1] > 1:
                out = out.squeeze().cpu().tolist()
            else:
                out = [out.item()]

            for i in range(len(out) - 1, -1, -1):
                if out[i] == 1 or out[i] == 2 or out[i] == 0:
                    out.remove(out[i])

            out_tokens = tokenize(out,vaca)

            out_str = ''.join(out_tokens)

            if step < 20:
                print('out_str', out_str)

            predicts.append(out_str)

    print('len(df)',len(df))
    print('len(predicts)',len(predicts))


    df['predict'] = predicts

    df.to_csv('result_transformer_v2.csv',index= False,sep = '\t')





def tokenize(ids,vaca):
    tokens = [ vaca[str(id)] for id in ids]
    return tokens


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir',default = 'dataset',type = str)
    parser.add_argument('--epochs', default = 5, type = int)
    parser.add_argument('--max_grad_norm',default = 1.0,type = float)
    parser.add_argument('--bert_model_path',default='Chinese-BERT-wwm',type = str)
    parser.add_argument('--batch_size',default=10 ,type = int)
    args = parser.parse_args()


    df = pd.read_csv('dataset/dev_noise.csv', sep='\t')
    print(df[0:20])

    # with open(args.bert_model_path+'/config.json','r') as f:
    #     dic = json.load(f)
    # args.vocab_size = dic['vocab_size']
    args.d_model = 768
    print(args)

    with open('chinese_vacabulary/id_to_token.json', 'r') as f:
        vaca = json.load(f)

    train_epochs(args,device)
    evaluation(device,vaca)

可以看看模型预测的效果:

看到出现了很多[OOV]这个是自己建立的词典没有收录完训练集中所有的字;直接采用bert模型提供的词典和tokenizer就不会出现这些问题——已验证。只不过是bert的词典长度为21128,有点大可能对推理速度有一定的影响。

 

这里简单的对文本纠错进行了一些尝试,没有做模型性能的评估,可以采用ROUGE和BLEU评价指标来自动话的评估模型生成的文本的质量。后续有机会可以把这两种自动化评估方法做一个实现。

over!

 

 

加油做一个快乐的炼丹师!

参考文章:

     图解Transformer(完整版)

     深入理解transformer源码

     Transformers中Batch Beam Search实现

  • 4
    点赞
  • 43
    收藏
    觉得还不错? 一键收藏
  • 20
    评论
评论 20
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值