Model architecture diagram:

Simplified forward-pass diagram (some intermediate linear layers and the coverage mechanism are omitted): in brief, the question and the document+question pair are each encoded by BERT, the answer decoder first attends over the question encoding, its output then attends over the document+question encoding, and the resulting hidden states and attention weights feed the generator projection and the copy/coverage mechanism.

Code and input/output:
import os
import sys
sys.path.append('.')

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

import args
from model_dir.modeling import BertPreTrainedModel, BertModel
from transformer.models import Decoder, Decoder_late
from dataset.dataloader import Dureader
from dataset.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
class BertForQuestionAnswering(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForQuestionAnswering, self).__init__(config)
        # shared BERT encoder, used for both the document+question pair and the question alone
        self.bert1 = BertModel(config)
        self.bert1_linear = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(True)
        )
        # project BERT hidden states to the decoder dimension
        self.bert2dec_ques = nn.Linear(config.hidden_size, args.d_model)
        self.bert2dec = nn.Linear(config.hidden_size, args.d_model)
        # decoders: first over the question encoding, then over the document+question encoding
        self.decoder_ques = Decoder(args.n_layers, args.d_k, args.d_v, args.d_model, args.d_ff,
                                    args.n_heads, args.max_ans_length, args.tgt_vocab_size,
                                    args.dropout, args.weighted_model)
        self.decoder = Decoder_late(args.n_layers, args.d_k, args.d_v, args.d_model, args.d_ff,
                                    args.n_heads, args.dropout, args.weighted_model)
        # generator projection and pointer gate
        self.tgt_proj = nn.Linear(args.d_model, args.tgt_vocab_size, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.p = torch.nn.Linear(args.d_model * 2, 1)
        self.loss_fct = CrossEntropyLoss()
    def forward(self, input_ids, input_ids_q, token_type_ids=None, can_answer=None,
                attention_mask=None, attention_mask_q=None, dec_inputs=None,
                dec_inputs_len=None, dec_targets=None, coverage=None):
        # encode the document+question pair and the question on its own
        sequence_output, first_token, _ = self.bert1(input_ids, token_type_ids=token_type_ids,
                                                     attention_mask=attention_mask,
                                                     output_all_encoded_layers=False)
        sequence_output = self.bert1_linear(sequence_output)  # [batch_size, doc+q_len, hidden_size]
        ques, ques_first, ques_last = self.bert1(input_ids_q, None, attention_mask=attention_mask_q,
                                                 output_all_encoded_layers=False)
        ques = self.bert2dec_ques(ques)  # [batch_size, q_len, d_model]

        # first decoding pass: the answer attends over the question encoding
        dec_outputs_ques, dec_self_attns_ques, dec_enc_attns_ques = self.decoder_ques(
            dec_inputs, dec_inputs_len, input_ids_q, ques, return_attn=True, is_initial=True)
        # dec_outputs_ques:    [batch_size, answer_len, d_model]
        # dec_self_attns_ques: [num_layers, batch_size, num_heads, answer_len, answer_len]
        # dec_enc_attns_ques:  [num_layers, batch_size, num_heads, answer_len, q_len]

        # second decoding pass: attend over the document+question encoding
        sequence_output = self.bert2dec(sequence_output)  # [batch_size, doc+q_len, d_model]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(
            dec_inputs, dec_inputs_len, dec_outputs_ques, input_ids, sequence_output,
            return_attn=True, is_initial=False)
        # dec_outputs:    [batch_size, answer_len, d_model]
        # dec_self_attns: [None, None, None, None]
        # dec_enc_attns:  [num_layers, batch_size, num_heads, answer_len, doc+q_len]

        dec_logits = self.tgt_proj(dec_outputs)  # [batch_size, answer_len, vocab_size]

        # coverage: take the last layer's last attention head
        dec_enc_attn = dec_enc_attns[-1][:, -1, :, :]  # [batch_size, answer_len, doc+q_len]
        contextual = torch.bmm(dec_enc_attn, sequence_output)  # [batch_size, answer_len, d_model]
        coverage = dec_enc_attn[:, 0, :]  # [batch_size, doc+q_len]
        coverage_loss = torch.zeros([1]).cpu()
        attn_values = torch.zeros([dec_enc_attn.size()[0], dec_enc_attn.size()[1],
                                   self.config.vocab_size]).cpu()  # [batch_size, answer_len, vocab_size]

        # copy from the source, step t = 0
        index = input_ids  # [batch_size, doc+q_len]
        attn = dec_enc_attn[:, 0, :]  # [batch_size, doc+q_len]
        attn_value = torch.zeros([attn.size()[0], self.config.vocab_size]).cpu()  # [batch_size, vocab_size]
        attn_value = attn_value.scatter_(1, index, attn)  # [batch_size, vocab_size]
        attn_values[:, 0, :] = attn_value  # store the step-0 copy distribution

        # iterate over the remaining steps, t >= 1
        for i in range(1, dec_enc_attn.size()[1]):
            current_att = dec_enc_attn[:, i, :]  # [batch_size, doc+q_len]
            coverage_loss = coverage_loss + torch.sum(
                torch.min(current_att.reshape(-1, 1), coverage.reshape(-1, 1)), 0)
            coverage += dec_enc_attn[:, i, :]  # [batch_size, doc+q_len]
            # copy from the source at step t = i
            index = input_ids  # [batch_size, doc+q_len]
            attn = dec_enc_attn[:, i, :]  # [batch_size, doc+q_len]
            attn_value = torch.zeros([attn.size()[0], self.config.vocab_size]).cpu()  # [batch_size, vocab_size]
            attn_value = attn_value.scatter_(1, index, attn)  # [batch_size, vocab_size]
            attn_values[:, i, :] = attn_value  # store the step-i copy distribution

        # pointer gate: mix the generator scores with the copy distribution
        embedding = dec_outputs_ques  # [batch_size, answer_len, d_model]
        p = self.sigmoid(self.p(torch.cat([embedding, contextual], -1)).squeeze(-1)).unsqueeze(-1)  # [batch_size, answer_len, 1]
        final_dist = (1 - p) * dec_logits + p * attn_values  # [batch_size, answer_len, vocab_size]
        final_dist = final_dist.view(-1, dec_logits.size(-1))  # [batch_size*answer_len, vocab_size]
        step_loss = self.loss_fct(final_dist, dec_targets.contiguous().view(-1))  # scalar (mean over all positions)
        return step_loss, coverage_loss
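The heart of the copy mechanism above is the scatter_ call, which writes each decoder-to-source attention weight into the vocabulary slot of the source token it points at, and the gate p, which mixes that copy distribution with the generator scores. Below is a minimal, self-contained sketch of those two steps on toy tensors; all sizes, values, and names (copy_dist, gen_logits) are invented for illustration and are not part of the original code.

import torch

batch_size, src_len, vocab_size = 2, 5, 10
input_ids = torch.randint(0, vocab_size, (batch_size, src_len))  # toy source token ids
attn = torch.softmax(torch.randn(batch_size, src_len), dim=-1)   # one decoding step's attention over the source

# scatter_ places attn[b, j] into copy_dist[b, input_ids[b, j]]; if a token id
# repeats in the source, later positions overwrite earlier ones
# (scatter_add_ would accumulate the attention mass instead)
copy_dist = torch.zeros(batch_size, vocab_size).scatter_(1, input_ids, attn)

# pointer gate: mix generator scores with the copy distribution, mirroring
# final_dist = (1 - p) * dec_logits + p * attn_values in forward() above
gen_logits = torch.randn(batch_size, vocab_size)
p = torch.sigmoid(torch.randn(batch_size, 1))
final_dist = (1 - p) * gen_logits + p * copy_dist  # [batch_size, vocab_size]

Note that in the forward pass dec_logits are unnormalised scores while attn_values are probabilities; the mixture is handed directly to CrossEntropyLoss, which applies its own log-softmax. Back in the script, the model is instantiated from the pretrained checkpoint and run on one dev batch: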
model = BertForQuestionAnswering.from_pretrained(
    './chinese_roberta_wwm_ext',
    cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(-1))).cpu()
data = Dureader()
batch = list(data.dev_iter)[0]
input_ids, input_mask, input_ids_q, input_mask_q, answer_ids, answer_mask, \
segment_ids, can_answer = \
    batch.input_ids, batch.input_mask, batch.input_ids_q, batch.input_mask_q, \
    batch.answer_ids, batch.answer_mask, batch.segment_ids, batch.can_answer
device = 'cpu'
coverage = torch.zeros([args.batch_size, args.max_seq_length]).to(device)
# teacher forcing: feed the answer shifted right, predict it shifted left
answer_inputs = answer_ids[:, :-1]
answer_targets = answer_ids[:, 1:]
answer_len = answer_mask.sum(1) - 1
model(input_ids, input_ids_q, token_type_ids=segment_ids,
      attention_mask=input_mask, attention_mask_q=input_mask_q,
      dec_inputs=answer_inputs, dec_inputs_len=answer_len,
      dec_targets=answer_targets, can_answer=can_answer, coverage=coverage)
Output: (tensor(26.5820, grad_fn=), tensor([47.9598], grad_fn=))
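The second value in the output is the coverage loss accumulated inside forward(): coverage starts as the attention of decoding step 0, and each later step adds the elementwise minimum of its attention and the coverage accumulated so far, penalising repeated attention to the same source positions. A self-contained sketch of the same accumulation on toy tensors (sizes and values are invented):

import torch

batch_size, answer_len, src_len = 2, 4, 6
dec_enc_attn = torch.softmax(torch.randn(batch_size, answer_len, src_len), dim=-1)

coverage = dec_enc_attn[:, 0, :]      # step-0 attention initialises the coverage vector
coverage_loss = torch.zeros(1)
for t in range(1, answer_len):
    current = dec_enc_attn[:, t, :]
    # overlap between this step's attention and everything attended to so far
    coverage_loss = coverage_loss + torch.sum(
        torch.min(current.reshape(-1, 1), coverage.reshape(-1, 1)), 0)
    coverage = coverage + current     # accumulate the attention history
print(coverage_loss)                  # a 1-element tensor, like the second output above

Note that the coverage tensor built in the script before calling the model is never actually consumed: forward() re-initialises coverage from the step-0 attention before using it.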
