Transformer Language Model: Inputs and Outputs

I. Contents

  1. Theory
    1.2 Why the decoder's training input equals its output (shifted by one position), instead of training token by token in a loop.
  2. Model
  3. Transformers GPT-2 model interface: inputs and outputs
  4. Transformers language-model interface
  5. Official transformers examples

II. Implementation

1. Theory
[Figure: causal mask applied to the attention matrix] Note: the matrix should be the mask matrix multiplied element-wise with the value matrix.
During generation, the cyan region is the masked area. With the input "I am a student", the model outputs "I" at step 1, "am" at step 2, "a" at step 3 and "student" at step 4; that is, row 1 of the matrix corresponds to step 1, row 2 to step 2, row 3 to step 3, and so on. Because the cyan cells are masked, the current step cannot see any later token: at step 1, the left matrix and the right (mask) matrix are multiplied element-wise, so the masked positions become 0.
[Figure: the same masking shown step by step]
Reading the masked attention:
[Figures: masked-attention walkthrough]
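To make the masking concrete, here is a small toy example (my own sketch, not taken from the figures above): a 4x4 score matrix for a length-4 sequence is combined with a lower-triangular causal mask, and after masked_fill and softmax each row only attends to its own and earlier positions.

import torch
import torch.nn.functional as F

# toy attention scores for a length-4 sequence such as "I am a student"
scores = torch.randn(4, 4)
# lower-triangular causal mask: position t may only attend to positions <= t
mask = torch.tril(torch.ones(4, 4)).bool()
# masked positions receive -1e9, so softmax drives them to ~0
attn = F.softmax(scores.masked_fill(~mask, -1e9), dim=-1)
print(attn)   # row 0 attends only to token 0, row 1 to tokens 0-1, and so on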
2. Model
Module 1: Multi-head attention

import torch
import torch.nn as nn
import torch.nn.functional as F
'''Scaled dot-product attention'''
class ScaledDotProductAttention(nn.Module):
    def __init__(self,temperature,attn_dropout=0.1):
        super(ScaledDotProductAttention,self).__init__()
        self.temperature=temperature
        self.dropout=nn.Dropout(attn_dropout)
    def forward(self,q,k,v,mask=None):
        #[batch,n_head,len_q,d_k] x [batch,n_head,d_k,len_k] -> [batch,n_head,len_q,len_k]
        attn=torch.matmul(q/self.temperature,k.permute(0,1,3,2))
        if mask is not None:
            attn=attn.masked_fill(mask==0,-1e9)
        attn=self.dropout(F.softmax(attn,dim=-1))
        output=torch.matmul(attn,v)
        return output,attn

'''Multi-head attention'''
class MultiHeadAttention(nn.Module):
    def __init__(self,n_head,d_model,d_k,d_v,dropout=0.1):
        super(MultiHeadAttention,self).__init__()
        self.n_head=n_head        #number of attention heads
        self.d_k=d_k             #per-head dimension of queries/keys
        self.d_v=d_v             #per-head dimension of values
        #linear projections; q, k, v all take d_model as input, output dims are n_head*d_k, n_head*d_k, n_head*d_v
        self.w_qs=nn.Linear(d_model,n_head*d_k,bias=False)
        self.w_ks=nn.Linear(d_model,n_head*d_k,bias=False)
        self.w_vs=nn.Linear(d_model,n_head*d_v,bias=False)

        self.fc=nn.Linear(n_head*d_v,d_model,bias=False)

        #single-head scaled dot-product attention (shared by all heads)
        self.attention=ScaledDotProductAttention(temperature=d_k**0.5)

        self.dropout=nn.Dropout(dropout)
        self.layer_norm=nn.LayerNorm(d_model,eps=1e-6)
    def forward(self,q,k,v,mask=None):
        '''
        :param q: decoder output (queries); in self-attention q, k, v are the same tensor
        :param k: encoder output (keys)
        :param v: encoder output (values)
        :param mask:
        :return:
        '''
        #per-head dimensions and number of heads
        d_k,d_v,n_head=self.d_k,self.d_v,self.n_head
        #batch size and sequence lengths
        sz_b,len_q,len_k,len_v=q.size(0),q.size(1),k.size(1),v.size(1)
        #residual
        residual=q
        #linear projections, then reshape to [batch, len, n_head, d]
        q=self.w_qs(q).view(sz_b,len_q,n_head,d_k)   #d_k: per-head query/key dimension
        k=self.w_ks(k).view(sz_b,len_k,n_head,d_k)
        v=self.w_vs(v).view(sz_b,len_v,n_head,d_v)
        q,k,v=q.permute(0,2,1,3),k.permute(0,2,1,3),v.permute(0,2,1,3)
        if mask is not None:
            mask=mask.unsqueeze(1)
        q,attn=self.attention(q,k,v,mask=mask)
        #transpose back and merge heads: [batch, len_q, n_head*d_v]
        q=q.permute(0,2,1,3).contiguous().view(sz_b,len_q,-1)
        q=self.dropout(self.fc(q))
        #residual connection
        q+=residual
        #layer normalization
        q=self.layer_norm(q)
        return q,attn
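A quick shape check for the module above (a sketch with arbitrary toy dimensions, not part of the original post): with batch size 2, sequence length 5, d_model 512 and 8 heads, the output keeps the [batch, len, d_model] shape and the attention weights are [batch, n_head, len_q, len_k].

# sanity check for MultiHeadAttention with toy dimensions
mha = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)
x = torch.randn(2, 5, 512)                       # [batch, len, d_model]
causal = torch.tril(torch.ones(1, 5, 5)).bool()  # [1, len, len]; unsqueezed to the head dimension inside forward
out, attn = mha(x, x, x, mask=causal)
print(out.shape)                                 # torch.Size([2, 5, 512])
print(attn.shape)                                # torch.Size([2, 8, 5, 5])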
Module 2: Position-wise feed-forward network
class PositionwiseFeedForward(nn.Module):
    def __init__(self,d_in,d_hid,dropout=0.1):
        super(PositionwiseFeedForward,self).__init__()
        self.w1=nn.Linear(d_in,d_hid)
        self.w2=nn.Linear(d_hid,d_in)
        self.layer_norm=nn.LayerNorm(d_in,eps=1e-6)    #a single int normalizes over the last dimension
        self.dropout=nn.Dropout(dropout)
    def forward(self,x):
        residual=x
        x=self.w2(F.relu(self.w1(x)))
        x=self.dropout(x)
        x+=residual
        x=self.layer_norm(x)
        return x
Module 3: Encoder layer and decoder layer
class EncoderLayer(nn.Module):
    def __init__(self,d_model,d_inner,n_head,d_k,d_v,dropout=0.1):
        super(EncoderLayer,self).__init__()
        self.slf_attn=MultiHeadAttention(n_head=n_head,d_model=d_model,d_k=d_k,d_v=d_v,
                                         dropout=dropout)
        self.pos_ffn=PositionwiseFeedForward(d_in=d_model,d_hid=d_inner,dropout=dropout)

    def forward(self,enc_input,slf_attn_mask=None):
        enc_output,enc_slf_attn=self.slf_attn(enc_input,enc_input,enc_input,mask=slf_attn_mask)
        enc_output=self.pos_ffn(enc_output)
        return enc_output,enc_slf_attn

'''Decoder layer'''
class DecoderLayer(nn.Module):
    def __init__(self,d_model,d_inner,n_head,d_k,d_v,dropout=0.1):
        super(DecoderLayer,self).__init__()
        self.slf_attn=MultiHeadAttention(n_head=n_head,d_model=d_model,d_k=d_k,d_v=d_v,dropout=dropout)
        self.enc_attn=MultiHeadAttention(n_head=n_head,d_model=d_model,d_k=d_k,d_v=d_v,dropout=dropout)
        self.pos_ffn=PositionwiseFeedForward(d_in=d_model,d_hid=d_inner,dropout=dropout)
    def forward(self,dec_input,enc_input,slf_attn_mask=None,dec_enc_attn_mask=None):
        #masked multi-head self-attention     #dec_input=[batch,trg_len,hidden]
        dec_output, dec_slf_attn=self.slf_attn(dec_input,dec_input,dec_input,mask=slf_attn_mask)     #self-attention with causal mask
        #multi-head encoder-decoder (cross) attention
        dec_output, dec_enc_attn=self.enc_attn(q=dec_output,k=enc_input,v=enc_input,mask=dec_enc_attn_mask)
        dec_output=self.pos_ffn(dec_output)
        return dec_output,dec_slf_attn,dec_enc_attn
The full model
import numpy as np

'''Positional encoding'''
class PositionalEncoding(nn.Module):
    def __init__(self,d_hid,n_postion=200):
        super(PositionalEncoding,self).__init__()
        self.register_buffer("pos_table",self._get_sinusoid_encoding_table(n_postion,d_hid))
    def _get_sinusoid_encoding_table(self,n_postion,d_hid):
        ''' Sinusoid position encoding table '''
        # TODO: make it with torch instead of numpy
        def get_position_angle_vec(position):
            return [position/np.power(10000,2*(hid_j//2)/d_hid) for hid_j in range(d_hid)]
        sinusoid_table=np.array([get_position_angle_vec(pos_i) for pos_i in range(n_postion)])
        sinusoid_table[:,0::2]=np.sin(sinusoid_table[:,0::2])    #dim 2i
        sinusoid_table[:,1::2]=np.cos(sinusoid_table[:,1::2])    #dim 2i+1
        return torch.FloatTensor(sinusoid_table).unsqueeze(0)
    def forward(self,x):
        return x+self.pos_table[:,:x.size(1)].clone().detach()
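For reference, the table built above implements the sinusoidal encoding from "Attention Is All You Need"; in the code, hid_j // 2 plays the role of i and d_hid is the embedding dimension:

PE(pos, 2i)   = sin(pos / 10000^(2i / d_hid))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_hid))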

'''Encoder'''
class Encoder(nn.Module):
    ''' Encoder based on self-attention '''
    def __init__(self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
            d_model, d_inner, pad_idx, dropout=0.1, n_position=200):
        super(Encoder,self).__init__()
        self.src_word_emb=nn.Embedding(n_src_vocab,d_word_vec,padding_idx=pad_idx)
        #positional encoding
        self.position_enc=PositionalEncoding(d_word_vec,n_postion=n_position)
        self.dropout=nn.Dropout(dropout)
        #stack of encoder layers
        self.layer_stack=nn.ModuleList([EncoderLayer(d_model=d_model,d_inner=d_inner,
                         n_head=n_head,d_k=d_k,d_v=d_v,dropout=dropout) for _ in range(n_layers)])
        self.layer_norm=nn.LayerNorm(d_model,eps=1e-6)

    def forward(self,src_seq,src_mask,return_attns=False):
        enc_slf_attn_list=[]
        enc_output=self.dropout(self.position_enc(self.src_word_emb(src_seq)))
        enc_output=self.layer_norm(enc_output)
        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn =enc_layer(enc_output,slf_attn_mask=src_mask)
            enc_slf_attn_list+=[enc_slf_attn] if return_attns else []
        if return_attns:
            return enc_output,enc_slf_attn_list
        return enc_output,
# Decoder
class Decoder(nn.Module):
    def __init__(self, n_trg_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
            d_model, d_inner, pad_idx, n_position=200, dropout=0.1):
        super(Decoder,self).__init__()
        self.trg_word_emb=nn.Embedding(n_trg_vocab,d_word_vec,padding_idx=pad_idx)
        self.position_enc=PositionalEncoding(d_word_vec,n_postion=n_position)
        self.dropout=nn.Dropout(dropout)
        self.layer_stack=nn.ModuleList([DecoderLayer(d_model=d_model,d_inner=d_inner,n_head=n_head,
                               d_k=d_k,d_v=d_v,dropout=dropout) for _ in range(n_layers)])
        self.layer_norm=nn.LayerNorm(d_model,eps=1e-6)

    def forward(self,trg_seq,trg_mask,enc_output,src_mask,return_attns=False):
        dec_slf_attn_list,dec_enc_attn_list=[],[]
        dec_output=self.dropout(self.position_enc(self.trg_word_emb(trg_seq)))
        dec_output=self.layer_norm(dec_output)
        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(dec_output,enc_output,
                                                               slf_attn_mask=trg_mask,dec_enc_attn_mask=src_mask)
            dec_slf_attn_list+=[dec_slf_attn] if return_attns else []
            dec_enc_attn_list+=[dec_enc_attn] if return_attns else []
        if return_attns:
            return dec_output,dec_slf_attn_list,dec_enc_attn_list
        return dec_output,

class Transformer(nn.Module):
    ''' A sequence to sequence model with attention mechanism. '''
    def __init__(self,n_src_vocab,n_trg_vocab,src_pad_idx,trg_pad_idx,d_word_vec=512,
                 d_model=512,d_inner=2048,n_layers=6,n_head=8,d_k=64,d_v=64,dropout=0.1,
                 n_position=200,trg_emb_prj_weight_sharing=True, emb_src_trg_weight_sharing=True):
        super(Transformer,self).__init__()

        self.src_pad_idx,self.trg_pad_idx=src_pad_idx,trg_pad_idx
        self.encoder=Encoder(n_src_vocab=n_src_vocab, d_word_vec=d_word_vec,
                             n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,d_model=d_model,
                             d_inner=d_inner, pad_idx=src_pad_idx, dropout=dropout, n_position=n_position)

        self.decoder=Decoder( n_trg_vocab=n_trg_vocab, d_word_vec=d_word_vec, n_layers=n_layers,
                              n_head=n_head, d_k=d_k, d_v=d_v,d_model=d_model, d_inner=d_inner,
                              pad_idx=trg_pad_idx, n_position=n_position, dropout=dropout)

        self.trg_word_prj=nn.Linear(d_model,n_trg_vocab,bias=False)
        #parameter initialization
        for p in self.parameters():
            if p.dim()>1:
                nn.init.xavier_uniform_(p)
        assert d_model==d_word_vec
        self.x_logit_scale=1.0
        # if trg_emb_prj_weight_sharing:
        #     self.trg_word_prj.weight=self.decoder.trg_word_emb.weight
        #     self.x_logit_scale=(d_model**-0.5)
        # if emb_src_trg_weight_sharing:      #share source/target word embeddings
        #     self.encoder.src_word_emb.weight=self.decoder.trg_word_emb.weight

        #src_seq=[batch,src_seq]     trg_seq=[batch,trg_seq]
    def forward(self,src_seq,trg_seq):
        src_mask=get_pad_mask(src_seq,self.src_pad_idx)
        trg_mask=get_pad_mask(trg_seq,self.trg_pad_idx) &get_subsequent_mask(trg_seq)
        #src_mask: padding mask, True where the token is not pad   [batch,1,src_seq]
        #trg_mask: intersection of the padding mask and the lower-triangular causal mask   [batch,trg_seq,trg_seq]; the middle dimension is broadcast, one row per time step
        enc_output,*_=self.encoder(src_seq,src_mask)
        dec_output,*_=self.decoder(trg_seq,trg_mask,enc_output,src_mask)
        seq_logit=self.trg_word_prj(dec_output)*self.x_logit_scale         #final linear projection to vocabulary logits
        return seq_logit
def get_pad_mask(seq,pad_idx):
    return (seq!=pad_idx).unsqueeze(-2)

def get_subsequent_mask(seq):
    sz_b,len_s=seq.size()
    subsequent_mask=(1-torch.triu(torch.ones((1,len_s,len_s),device=seq.device),diagonal=1)).bool()
    return subsequent_mask         #lower-triangular matrix, including the diagonal
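To see what the two helpers return (a small example I added), take a padded batch with pad_idx = 0:

# toy example of the two mask helpers, assuming pad_idx = 0
seq = torch.tensor([[5, 7, 9, 0, 0]])            # [batch=1, len=5], the last two tokens are padding
pad_mask = get_pad_mask(seq, pad_idx=0)          # [1, 1, 5]: True where the token is not padding
sub_mask = get_subsequent_mask(seq)              # [1, 5, 5]: True on and below the diagonal
trg_mask = pad_mask & sub_mask                   # [1, 5, 5]: what decoder self-attention is allowed to see
print(pad_mask.shape, sub_mask.shape, trg_mask.shape)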
Beam search prediction
# sequence prediction with beam search
class Translator(nn.Module):
    def __init__(self,model,beam_size,max_seq_len,src_pad_idx,trg_pad_idx,trg_bos_idx,trg_eos_idx):
        super(Translator,self).__init__()
        self.alpha=0.7
        self.beam_size=beam_size
        self.max_seq_len=max_seq_len
        self.src_pad_idx=src_pad_idx
        self.trg_pad_idx=trg_pad_idx
        self.trg_bos_idx=trg_bos_idx
        self.trg_eos_idx=trg_eos_idx

        self.model=model
        self.model.eval()
        self.register_buffer("init_seq",torch.LongTensor([[trg_bos_idx]]))
        self.register_buffer("blank_seqs",torch.full((beam_size,max_seq_len),trg_pad_idx,dtype=torch.long))
        self.blank_seqs[:,0]=self.trg_bos_idx
        self.register_buffer("len_map",torch.arange(1,max_seq_len+1,dtype=torch.long).unsqueeze(0))

    def _get_init_state(self,src_seq,src_mask):
        beam_size=self.beam_size
        enc_output,*_=self.model.encoder(src_seq,src_mask)
        dec_output=self._model_decode(self.init_seq,enc_output,src_mask)

        best_k_probs,best_k_idx=dec_output[:,-1,:].topk(beam_size)
        scores=torch.log(best_k_probs).view(beam_size)
        gen_seq=self.blank_seqs.clone().detach()
        gen_seq[:,1]=best_k_idx[0]
        enc_output=enc_output.repeat(beam_size,1,1)
        return enc_output,gen_seq,scores
    #run the decoder and project to vocabulary probabilities
    def _model_decode(self,trg_seq,enc_output,src_mask):
        trg_mask=get_subsequent_mask(trg_seq)
        dec_output,*_=self.model.decoder(trg_seq,trg_mask,enc_output,src_mask)
        dec_output=F.softmax(self.model.trg_word_prj(dec_output),dim=-1)
        return dec_output
    #beam search
    def _get_best_score_and_idx(self,gen_seq,dec_output,scores,step):
        assert len(scores.size())==1
        beam_size=self.beam_size
        #expand: beam_size beams x beam_size candidates = beam_size*beam_size hypotheses
        best_k2_probs,best_k2_idx=dec_output[:,-1,:].topk(beam_size)
        #[beam_size,beam_size]

        scores=torch.log(best_k2_probs).view(beam_size,-1)+scores.view(beam_size,1)
        #keep the top beam_size scores and their flat indices
        scores,best_k_idx_in_k2=scores.view(-1).topk(beam_size)

        #map the flat indices back to (beam row, candidate column) coordinates
        best_k_r_idxs,best_k_c_idxs=best_k_idx_in_k2//beam_size, best_k_idx_in_k2%beam_size
        #token ids chosen at this step
        best_k_idx=best_k2_idx[best_k_r_idxs,best_k_c_idxs]

        gen_seq[:,:step]=gen_seq[best_k_r_idxs,:step]
        gen_seq[:,step]=best_k_idx
        return gen_seq,scores


    #beam-search decoding entry point
    def translate_sentence(self,src_seq):
        assert src_seq.size(0)==1
        src_pad_idx,trg_eos_idx=self.src_pad_idx,self.trg_eos_idx
        max_seq_len,beam_size,alpha=self.max_seq_len,self.beam_size,self.alpha

        with torch.no_grad():
            src_mask=get_pad_mask(src_seq,src_pad_idx)
            #initialize the beam state
            enc_output,gen_seq,scores=self._get_init_state(src_seq,src_mask)
            ans_idx=0
            for step in range(2,max_seq_len):
                dec_output=self._model_decode(gen_seq[:,:step],enc_output,src_mask)
                gen_seq,scores=self._get_best_score_and_idx(gen_seq,dec_output,scores,step)

                eos_locs=(gen_seq==trg_eos_idx)    #positions where eos has been generated

                seq_lens,_=self.len_map.masked_fill(~eos_locs,max_seq_len).min(1)   #position of the first eos in each beam

                if(eos_locs.sum(1)>0).sum(0).item()==beam_size:
                    _,ans_idx=scores.div(seq_lens.float()**alpha).max(0)   #beam with the highest length-normalized score
                    ans_idx=ans_idx.item()
                    break

        #gen_seq: [beam_size, max_seq_len]
        return gen_seq[ans_idx][:seq_lens[ans_idx]].tolist()
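A minimal smoke test for the beam-search wrapper (my own sketch; the tiny dimensions and the choice of 0/1/2 as pad/bos/eos ids are assumptions, and the untrained model produces random output):

# smoke test with a tiny untrained Transformer; ids 0/1/2 as pad/bos/eos are assumptions of this sketch
toy_model = Transformer(n_src_vocab=20, n_trg_vocab=20, src_pad_idx=0, trg_pad_idx=0,
                        d_word_vec=32, d_model=32, d_inner=64, n_layers=2, n_head=4,
                        d_k=8, d_v=8, n_position=50)
translator = Translator(model=toy_model, beam_size=3, max_seq_len=12,
                        src_pad_idx=0, trg_pad_idx=0, trg_bos_idx=1, trg_eos_idx=2)
src = torch.LongTensor([[4, 5, 6, 7]])           # one source sentence, no padding
print(translator.translate_sentence(src))        # list of target token ids (meaningless until trained)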
Training
import os
import shutil
import logging
from torch.utils.tensorboard import SummaryWriter

def main():
    args=args_parse()
    src_vocab,trg_vocab=MyProcessor().get_src_trg_vocab(datasets=MyProcessor.get_data(args.train_file),
                                            src_vocab=args.src_vocab,trg_vocab=args.trg_vocab)

    collator=MyCollator(src_vocab,trg_vocab,max_len=args.max_len,min_len=args.min_len)
    train_loader,eval_loader=make_loader(collator,args.train_file,args.eval_file,batch_size=args.batch_size)
    print(len(src_vocab))
    model=Transformer(n_src_vocab=len(src_vocab),n_trg_vocab=len(trg_vocab),src_pad_idx=src_vocab[PAD_TOKEN],
                      trg_pad_idx=trg_vocab[PAD_TOKEN],d_word_vec=args.d_word_vec,
                      d_model=args.d_model,d_inner=args.d_inner,n_layers=args.n_layers,n_head=args.n_head,
                      d_k=args.d_k,d_v=args.d_v,dropout=args.dropout,n_position=args.n_position,
                      trg_emb_prj_weight_sharing=args.trg_emb_prj_weight_sharing,
                      emb_src_trg_weight_sharing=args.emb_src_trg_weight_sharing
                      )
    model.to(args.device)
    train(model,args,train_loader,eval_loader,trg_pad_id=trg_vocab[PAD_TOKEN],model_name="transformer")

def train(model,args,train_loader,eval_loader,trg_pad_id,model_name):
    criterion=nn.CrossEntropyLoss(ignore_index=trg_pad_id)
    optimizer=torch.optim.Adam(model.parameters(),lr=args.lr)
    if os.path.exists(args.log):
        shutil.rmtree(args.log)
    writer=SummaryWriter(args.log)
    step=0
    best_loss=1e3
    for epoch in range(args.max_epoch):
        model.train()
        for id,batch in enumerate(train_loader):
            src_tensor,trg_tensor=batch
            src_tensor=src_tensor.to(args.device)      #tensor.to() is not in-place, assign the result
            trg_tensor=trg_tensor.to(args.device)
            labels=trg_tensor[:,1:]                    #labels: the target sequence shifted left by one

            logits=model(src_tensor,trg_tensor[:,:-1]) #teacher forcing: feed the target without its last token
            loss=criterion(logits.view(-1,logits.size(-1)),labels.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step+=1
            if id%args.log_step==0:
                not_ignore=labels.ne(trg_pad_id)
                num_targets=not_ignore.long().sum().item()
                correct=(torch.max(logits,dim=-1)[1].cpu().data==labels.cpu().data)&not_ignore
                acc=100*correct.sum()/num_targets
                logging.info("train epoch {} id {} loss {} acc {}".format(epoch,id,loss.item(),acc.item()))
                writer.add_scalar("train_acc",acc.item(),step)
                writer.add_scalar("train_loss",loss.item(),step)

                if loss.item()<best_loss:
                    best_loss=loss.item()
                    path=os.path.join(args.save,"transformer.pt")
                    if not os.path.exists(args.save):
                        os.makedirs(args.save,exist_ok=True)
                    #torch.save(model.state_dict(),path)
                    logging.info(path)

            if step%args.eval_step==0:
                acc_,loss_=valid(model,args,eval_loader,trg_pad_id)
                logging.info("eval step {} loss {} acc {}".format(epoch, loss_, acc_))
                writer.add_scalar("eval_acc", acc_, step)
                writer.add_scalar("eval_loss", loss_, step)
    return

def valid(model,args,eval_loader,trg_pad_id):
    model.eval()
    criterion=nn.CrossEntropyLoss(ignore_index=trg_pad_id)
    with torch.no_grad():
        total_loss=0
        total_nums=0
        total_correct=0
        for id,batch in enumerate(eval_loader):
            src_tensor,trg_tensor=batch
            src_tensor=src_tensor.to(args.device)
            trg_tensor=trg_tensor.to(args.device)
            labels=trg_tensor[:,1:]
            logits=model(src_tensor,trg_tensor[:,:-1])
            loss=criterion(logits.view(-1,logits.size(-1)),labels.reshape(-1))
            total_loss+=loss.item()
            not_ignore=labels.ne(trg_pad_id)
            num_targets=not_ignore.long().sum().item()
            total_nums+=num_targets
            correct=((torch.max(logits,dim=-1)[1].cpu().data==labels.cpu().data)&not_ignore).sum()
            total_correct+=correct.item()
        return 100*total_correct/total_nums,total_loss/(id+1)
Prediction
class Inference():
    def __init__(self):
        args = args_parse()
        self.args=args
        src_vocab, trg_vocab = MyProcessor().get_src_trg_vocab(datasets=MyProcessor.get_data(args.train_file),
                                                               src_vocab=args.src_vocab, trg_vocab=args.trg_vocab)
        self.trg_id2vocab={v:k for k,v in trg_vocab.items()}
        self.collator = MyCollator(src_vocab, trg_vocab, max_len=args.max_len, min_len=args.min_len)
        self.model = Transformer(n_src_vocab=len(src_vocab), n_trg_vocab=len(trg_vocab), src_pad_idx=src_vocab[PAD_TOKEN],
                            trg_pad_idx=trg_vocab[PAD_TOKEN], d_word_vec=args.d_word_vec,
                            d_model=args.d_model, d_inner=args.d_inner, n_layers=args.n_layers, n_head=args.n_head,
                            d_k=args.d_k, d_v=args.d_v, dropout=args.dropout, n_position=args.n_position,
                            trg_emb_prj_weight_sharing=args.trg_emb_prj_weight_sharing,
                            emb_src_trg_weight_sharing=args.emb_src_trg_weight_sharing
                            )
        path = os.path.join(args.save, "transformer.pt")
        checkpoint=torch.load(path,map_location="cpu")
        self.model.load_state_dict(checkpoint)
        self.model.to(args.device)
        self.model.eval()
        self.translator=Translator(model=self.model,beam_size=args.beam_size,max_seq_len=args.max_len,
                                   src_pad_idx=src_vocab[PAD_TOKEN],trg_pad_idx=trg_vocab[PAD_TOKEN],trg_bos_idx=trg_vocab[SOS_TOKEN],
                                   trg_eos_idx=trg_vocab[EOS_TOKEN])
    def forward(self,text_list):
        result=[]
        for text in text_list:
            src_tensor=self.collator.convert_to_tensor(text)
            src_tensor=src_tensor.to(self.args.device)
            output=self.translator.translate_sentence(src_tensor)
            predict="".join([self.trg_id2vocab[w] for w in output][1:-1])
            result.append(predict)
        return result
if __name__ == '__main__':
    m=Inference()
    l=['老是较书。',
         '感谢等五分以后,碰到一位很棒的奴生跟我可聊。',
         '遇到一位很棒的奴生跟我聊天。',
         '遇到一位很美的女生跟我疗天。',
         '他们只能有两个选择:接受降新或自动离职。',
         '王天华开心得一直说话。']
    res=m.forward(l)
    r=list(zip(l,res))
    for w in r:
        print("input: ",w[0])
        print("output: ", w[1])
        print()

III. Using the GPT-2 interface
GPT-2 model + language-model head; it corresponds to the decoder part of the Transformer.

Example config (the # comments here are annotations only; an actual JSON file cannot contain them):

{
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "n_ctx": 300,                # maximum sequence length
  "n_embd": 256,               # embedding size
  "n_head": 8,
  "n_layer": 6,
  "n_positions": 300,          # maximum sequence length
  "vocab_size": 1905           # vocabulary size
}
from transformers import GPT2Tokenizer, GPT2Model,GPT2LMHeadModel
import transformers
def create_model(gpt_config,vocab_size,pretrained_model):
    if pretrained_model:
        model=GPT2LMHeadModel.from_pretrained(pretrained_model)
    else:
        model_config=transformers.GPT2Config.from_json_file(gpt_config)
        model=GPT2LMHeadModel(config=model_config)
        model.resize_token_embeddings(vocab_size)
    return model,model.config.to_dict().get("n_ctx")  
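For example, training from scratch with the small config shown above would look roughly like this (the config path is a placeholder, and 1905 is the vocabulary size from the config):

# hypothetical call: no pretrained weights, so the model is built from the JSON config
model, n_ctx = create_model(gpt_config="config/gpt2_config.json",   # placeholder path to the config shown above
                            vocab_size=1905,
                            pretrained_model=None)
print(n_ctx)   # 300 with the config above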

Introduction to the Hugging Face API
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
logits = outputs.logits
Training:

def main():
    args=args_parse()
    processor=MyProcessor()
    trainSets=processor.get_pairs(args.train_file)
    evalSets=processor.get_pairs(args.eval_file)
    vocab2id=processor.make_vocab2id(trainSets,args.vocab2id_path)
    model,c_ctx=create_model(gpt_config=args.gpt_config,vocab_size=len(vocab2id),pretrained_model=args.pretrained_model)
    args.max_len=c_ctx
    model.to(args.device)
    logging.info("c_ctx {} ,vocab_size {}".format(c_ctx,len(vocab2id)))
    #count model parameters
    num_parameters=0
    parameters=model.parameters()
    for parameter in parameters:
        num_parameters+=parameter.numel()
    logging.info("number of model parameters: {}".format(num_parameters))
    collator=MyCallator(vocab2id=vocab2id,max_len=args.max_len)
    train_loader,eval_loader=make_loader(collator,trainSets,evalSets,batch_size=args.batch_size)

    train(model,args,train_loader,eval_loader,model_name="gpt2",pad_id=vocab2id[PAD_TOKENS])

def train(model,args,train_loader,eval_loader,model_name,pad_id):
    criterion=nn.CrossEntropyLoss(ignore_index=pad_id,reduction="sum")
    optimizer=torch.optim.Adam(model.parameters(),lr=args.lr)
    writer=SummaryWriter(args.log)
    scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=100*args.max_epoch,
                                                         eta_min=0)
    step=0
    best_loss=1e3
    for epoch in range(args.max_epoch):
        model.train()
        for id,batch in enumerate(train_loader):
            input_ids=batch
            input_ids=input_ids.to(args.device)        #tensor.to() is not in-place, assign the result
            output_ids=input_ids[:,1:].to(args.device) #labels: the input shifted left by one
            #input: input_ids=[batch,seq_len]; the output is a CausalLMOutputWithCrossAttentions (or a tuple)
            output=model.forward(input_ids)
            logits=output[0][:,:-1,:].contiguous()     #prediction logits [batch,seq_len-1,vocab_size]; drop the last position

            # input_ids=batch[:,:]  label_ids=[:,1:]
            # logits=batch[:,:-1]   logits and labels are offset by one position, so each position predicts the next token
            
            loss=criterion(logits.view(-1,logits.size(-1)),output_ids.reshape(-1))
            not_ignore=output_ids.ne(pad_id)
            num_targets=not_ignore.long().sum().item()
            loss=loss/num_targets

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),args.max_grad_norm)

            optimizer.step()
            step+=1
            if id%args.log_step==0:
                correct=(torch.max(logits,dim=-1)[1].cpu().data==output_ids.cpu().data)&not_ignore
                correct=correct.sum()
                acc=100*correct/num_targets
                logging.info("train epoch {} batch {} loss {} acc {}".format(epoch,id,loss.item(),acc.item()))
                writer.add_scalar("train_acc",acc.item(),step)
                writer.add_scalar("train_loss",loss.item(),step)
            if step%args.eval_step==0:
                scheduler.step()
                acc_,loss_=evalid(model,args,eval_loader,model_name,pad_id)
                logging.info("eval epoch {} step {} loss {} acc {}".format(epoch, step, loss_, acc_))
                writer.add_scalar("eval_acc", acc_, step)
                writer.add_scalar("eval_loss", loss_, step)
        cur_loss=loss.item()
        logging.info("cur loss {}".format(cur_loss))
        if cur_loss<best_loss:
            best_loss=cur_loss
            path=os.path.join(args.save_path,"gpt2.pt")
            #torch.save(model.state_dict(),path)
            logging.info(path)
def evalid(model,args,eval_loader,model_name,pad_id):
    model.eval()
    criterion=nn.CrossEntropyLoss(ignore_index=pad_id,reduction="sum")
    total_num_target=0
    total_loss=0
    total_correct=0
    with torch.no_grad():
        for id,batch in enumerate(eval_loader):
            input_ids=batch
            input_ids=input_ids.to(args.device)
            label_ids=input_ids[...,1:]
            output=model.forward(input_ids)
            output=output[0]
            logits=output[:,:-1].contiguous()

            loss=criterion(logits.view(-1,logits.size(-1)),label_ids.reshape(-1))
            not_ignore=label_ids.ne(pad_id)
            num_target=not_ignore.long().sum().item()
            total_loss+=loss.item()
            total_num_target+=num_target

            correct=(torch.max(logits,dim=-1)[1].cpu().data==label_ids.cpu().data)&not_ignore
            correct=torch.sum(correct).item()
            total_correct+=correct
        acc=100*total_correct/total_num_target
        loss=total_loss/total_num_target
        return acc,loss

Prediction:

class Inference:
    def __init__(self):
        args = args_parse()
        processor = MyProcessor()
        vocab2id = processor.make_vocab2id(None, args.vocab2id_path)
        self.model, c_ctx = create_model(gpt_config=args.gpt_config, vocab_size=len(vocab2id),
                                    pretrained_model=args.pretrained_model)
        args.max_len = c_ctx
        path=os.path.join(args.save_path,"gpt2.pt")
        checkpoint=torch.load(path,map_location="cpu")
        self.model.load_state_dict(checkpoint)
        self.model.to(args.device)
        self.vocab2id=vocab2id
        self.max_len=100
        self.id2vocab={k:v for v,k in vocab2id.items()}
        self.unk_token_id=self.vocab2id[UNK_TOKENS]
        self.args=args
        print("model loaded")
    def predict(self,sentence):
        sentence=[SOS_TOKENS]+list(sentence.lower())+[SOS_TOKENS]
        sent_id=[self.vocab2id.get(w,self.unk_token_id) for w in sentence][:self.max_len]
        sent_tensor=torch.tensor([sent_id]).long()
        sent_tensor=sent_tensor.to(self.args.device)           #tensor.to() is not in-place, assign the result
        for id in range(sent_tensor.size(-1),self.max_len):    #generate until max_len is reached
            output=self.model.forward(sent_tensor)[0]          #input: input_ids; output: logits
            pred_id=torch.max(output[:,-1],dim=-1)[1]          #greedy prediction at the last position
            pred_id=pred_id.unsqueeze(1)                       #(batch_size, 1)
            sent_tensor=torch.cat([sent_tensor,pred_id],dim=1) #append the predicted token
        sent_numpy=sent_tensor.cpu().data.numpy()[0,:]
        translate=[self.id2vocab[w] for w in sent_numpy]
        res=[]
        flag=0
        for w in translate:
            if w==SOS_TOKENS:
                flag+=1
            if w!=EOS_TOKENS:
                if flag>1:
                    res.append(w)
            else :
                break
        return "".join(res[1:])
if __name__ == '__main__':
    input = ['老是较书。',
             '感谢等五分以后,碰到一位很棒的奴生跟我可聊。',
             '遇到一位很棒的奴生跟我聊天。',
             '遇到一位很美的女生跟我疗天。',
             '他们只能有两个选择:接受降新或自动离职。',
             '王天华开心得一直说话。']

    model = Inference()
    output=[]
    for w in input:
        res=model.predict(w)
        output.append(res)
    for a, b in zip(input, output):
        print("input :", a)
        print("output: ", b)
        print()
IV. Transformers language-model interface
  Generation interface: transformers.generation.utils.generate produces the output one token at a time in a loop (see the sketch after the links below).
V. Official transformers examples
  GitHub: https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling
  Docs: https://huggingface.co/docs/transformers/tasks/language_modeling
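The manual greedy loop in the Inference class above corresponds to what the library's generate() interface does internally. A rough, illustrative equivalent using the public gpt2 checkpoint (not the custom model trained above):

# illustrative use of the generate() interface; loads the public "gpt2" checkpoint
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is", return_tensors="pt").input_ids
# greedy decoding: generate() appends one token per step until max_length or eos is reached
out = model.generate(input_ids, max_length=20, do_sample=False)
print(tokenizer.decode(out[0]))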