pg-net training: the model part

Model architecture diagram:
[figure: model architecture]
Simplified forward-pass diagram (some intermediate linear layers and the coverage mechanism omitted):
[figure: simplified forward pass]

Code with its inputs and outputs:

import sys
import os
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

sys.path.append('.')    # make the project-local modules importable before importing them
import args    # project-level hyperparameters (d_model, n_layers, ...)
from model_dir.modeling import BertPreTrainedModel, BertModel
from transformer.models import Decoder, Decoder_late
from dataset.dataloader import Dureader
from dataset.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
class BertForQuestionAnswering(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForQuestionAnswering, self).__init__(config)
        self.bert1 = BertModel(config)    # shared BERT encoder for doc+question and question-only inputs
        self.bert1_linear = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(True)
        )
        self.bert2dec_ques = nn.Linear(config.hidden_size, args.d_model)    # project question encoding to decoder dim
        self.bert2dec = nn.Linear(config.hidden_size, args.d_model)         # project doc+question encoding to decoder dim
        # decoders: a first pass over the question, then a second pass over doc+question
        self.decoder_ques = Decoder(args.n_layers, args.d_k, args.d_v, args.d_model, args.d_ff, args.n_heads,
                                    args.max_ans_length, args.tgt_vocab_size, args.dropout, args.weighted_model)
        self.decoder = Decoder_late(args.n_layers, args.d_k, args.d_v, args.d_model, args.d_ff, args.n_heads,
                                    args.dropout, args.weighted_model)
        self.tgt_proj = nn.Linear(args.d_model, args.tgt_vocab_size, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.p = torch.nn.Linear(args.d_model * 2, 1)    # copy/generate gate of the pointer mechanism
        self.loss_fct = CrossEntropyLoss()
    def forward(self, input_ids, input_ids_q, token_type_ids=None, can_answer=None, attention_mask=None,
                attention_mask_q=None, dec_inputs=None, dec_inputs_len=None, dec_targets=None, coverage=None):
        sequence_output, first_token, _ = self.bert1(input_ids, token_type_ids=token_type_ids,
                                                     attention_mask=attention_mask, output_all_encoded_layers=False)
        sequence_output = self.bert1_linear(sequence_output)    # [batch_size,doc+q_len,model_dim]
        ques, ques_first, ques_last = self.bert1(input_ids_q, None, attention_mask=attention_mask_q,
                                                 output_all_encoded_layers=False)
        ques = self.bert2dec_ques(ques)    # [batch_size,q_len,model_dim]

        dec_outputs_ques, dec_self_attns_ques, dec_enc_attns_ques = self.decoder_ques(
            dec_inputs, dec_inputs_len, input_ids_q, ques, return_attn=True, is_initial=True)
        # dec_outputs_ques: [batch_size,answer_len,model_dim]
        # dec_self_attns_ques:[num_layers,batch_size,num_heads,answer_len,answer_len]
        # dec_enc_attns_ques:[num_layers,batch_size,num_heads,answer_len,q_len]
        sequence_output = self.bert2dec(sequence_output)    # [batch_size,doc+q_len,d_model]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(
            dec_inputs, dec_inputs_len, dec_outputs_ques, input_ids, sequence_output, return_attn=True, is_initial=False
        )
        # dec_outputs:[batch_size,a_len,d_model]
        # dec_self_attns:[None,None,None,None]
        # dec_enc_attns:[num_layers,batch_size,num_heads,answer_len,doc+q_len]
        dec_logits = self.tgt_proj(dec_outputs)    # [batch_size,a_len,tgt_vocab_size]
        # coverage and copy distribution, both driven by the decoder-encoder attention
        dec_enc_attn = dec_enc_attns[-1][:, -1, :, :]    # last layer, last head: [batch_size,answer_len,doc+q_len]
        contextual = torch.bmm(dec_enc_attn, sequence_output)    # [batch_size,answer_len,model_dim]
        coverage = dec_enc_attn[:, 0, :]    # initialize coverage with the step-0 attention: [batch_size,doc+q_len]
        coverage_loss = torch.zeros([1]).cpu()
        attn_values = torch.zeros([dec_enc_attn.size(0), dec_enc_attn.size(1), self.config.vocab_size]).cpu()    # per-step copy distribution: [batch_size,answer_len,vocab_size]
        # copy from the source side, step t = 0
        index = input_ids    # [batch_size,doc+q_len]
        attn = dec_enc_attn[:, 0, :]    # [batch_size,doc+q_len]
        attn_value = torch.zeros([attn.size(0), self.config.vocab_size]).cpu()    # [batch_size,vocab_size]
        attn_value = attn_value.scatter_(1, index, attn)    # scatter attention weights onto their vocabulary ids
        attn_values[:, 0, :] = attn_value    # fill the t = 0 slice of attn_values
        # iterate over the decoding steps from t = 1
        for i in range(1, dec_enc_attn.size(1)):
            current_att = dec_enc_attn[:, i, :]    # [batch_size,doc+q_len]
            # coverage loss term: sum of the elementwise min between the current attention and the coverage so far
            coverage_loss = coverage_loss + torch.sum(torch.min(current_att.reshape(-1, 1), coverage.reshape(-1, 1)), 0)
            coverage = coverage + current_att    # accumulate out of place, so the view into dec_enc_attn is not mutated
            # copy from the source side
            attn = current_att    # [batch_size,doc+q_len]
            attn_value = torch.zeros([attn.size(0), self.config.vocab_size]).cpu()    # [batch_size,vocab_size]
            attn_value = attn_value.scatter_(1, index, attn)    # index is input_ids, as above
            attn_values[:, i, :] = attn_value    # fill the t = i slice of attn_values
        embedding = dec_outputs_ques    # [batch_size,answer_len,model_dim]
        p = self.sigmoid(self.p(torch.cat([embedding, contextual], -1)))    # copy gate: [batch_size,answer_len,1]
        final_dist = (1 - p) * dec_logits + p * attn_values    # [batch_size,answer_len,vocab_size]
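        # Note: dec_logits are unnormalized scores while each row of attn_values sums to 1;
        # the pointer-generator of See et al. (2017) mixes two probability distributions,
        # so applying a softmax to dec_logits before mixing would match that formulation.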
        final_dist = final_dist.view(-1, dec_logits.size(-1))    # [batch_size*answer_len,vocab_size]
        step_loss = self.loss_fct(final_dist, dec_targets.contiguous().view(-1))    # scalar, mean over target tokens
        return step_loss, coverage_loss
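
One detail of the copy step that is easy to miss: scatter_(1, index, attn) writes each attention weight to the vocabulary id of the corresponding source token, but when a token occurs more than once in the source only one of its weights survives, because scatter_ overwrites on duplicate indices; scatter_add_ accumulates them instead. A minimal, self-contained sketch with toy sizes (not the project's real data):

import torch

attn = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                     [0.25, 0.25, 0.25, 0.25]])    # [batch=2, src_len=4]
input_ids = torch.tensor([[1, 3, 3, 5],
                          [0, 2, 4, 4]])           # token ids, vocab_size = 6

overwrite = torch.zeros(2, 6).scatter_(1, input_ids, attn)
accumulate = torch.zeros(2, 6).scatter_add_(1, input_ids, attn)

print(overwrite[0, 3])     # 0.2 or 0.3: one weight of the repeated id 3 is lost
print(accumulate[0, 3])    # 0.5: both weights of the repeated id 3 are summed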

    
model = BertForQuestionAnswering.from_pretrained(
    './chinese_roberta_wwm_ext',
    cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(-1))
).cpu()
data = Dureader()
batch = next(iter(data.dev_iter))    # one dev batch is enough for a smoke test
input_ids, input_mask, input_ids_q, input_mask_q, answer_ids, answer_mask, segment_ids, can_answer = \
    batch.input_ids, batch.input_mask, batch.input_ids_q, batch.input_mask_q, \
    batch.answer_ids, batch.answer_mask, batch.segment_ids, batch.can_answer
device = 'cpu'

coverage = torch.zeros([args.batch_size, args.max_seq_length]).to(device)    # note: forward() re-initializes coverage internally
answer_inputs = answer_ids[:, :-1]     # decoder inputs: answer without its last token (teacher forcing)
answer_targets = answer_ids[:, 1:]     # decoder targets: answer shifted left by one token
answer_len = answer_mask.sum(1) - 1    # effective answer lengths after the shift
model(input_ids, input_ids_q, token_type_ids=segment_ids,
                         attention_mask=input_mask, attention_mask_q=input_mask_q,
                         dec_inputs=answer_inputs, dec_inputs_len=answer_len,
                         dec_targets=answer_targets, can_answer=can_answer, coverage=coverage)

Output: (tensor(26.5820, grad_fn=<...>), tensor([47.9598], grad_fn=<...>))
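
The per-step Python loop in forward can be collapsed into a few tensor operations, which is convenient for sanity-checking shapes. A hedged sketch (the helper names copy_dist and coverage_penalty are illustrative, not part of the project; it also uses scatter_add_, so repeated source tokens accumulate, whereas the loop above keeps one weight per duplicated id):

import torch

def copy_dist(dec_enc_attn, input_ids, vocab_size):
    # dec_enc_attn: [batch, answer_len, src_len]; input_ids: [batch, src_len]
    index = input_ids.unsqueeze(1).expand(-1, dec_enc_attn.size(1), -1)    # [batch, answer_len, src_len]
    values = torch.zeros(dec_enc_attn.size(0), dec_enc_attn.size(1), vocab_size)
    return values.scatter_add_(2, index, dec_enc_attn)    # [batch, answer_len, vocab_size]

def coverage_penalty(dec_enc_attn):
    # coverage before step t is the sum of the attentions of steps < t; it is zero at
    # t = 0, so including t = 0 adds nothing, matching the loop that starts at t = 1
    coverage = torch.cumsum(dec_enc_attn, dim=1) - dec_enc_attn
    return torch.min(dec_enc_attn, coverage).sum()    # scalar

How the two returned losses are combined is up to the training loop, which this post does not show; in See et al. (2017) the coverage term enters the total loss scaled by a hyperparameter λ.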
