解读ACL 2020的一篇paper (Recurrent Chunking Mechanisms for Long-text machine reading comprehension)的源码

本文的目的是解读Recurrent Chunking Mechanisms for Long-text machine reading comprehension这篇论文的GitHub上的代码。

我会在代码的基础上添加尽可能多的注释。

首先:

def read_coqa_examples(input_file,is_training=True,use_history=False,n_history=-1):
    '''
    由于CoQA是对话型阅读理解数据集,所以后面的问题依赖于前面的问题与答案,但是这篇论文重点不在于
    专门针对对话型数据集,所以并没有加上之前的问题与答案,要想进一步提升模型在coqa上的效果,是需要加上
    之前的问题与答案的
    '''
    total_cnt=0
    with open(input_file) as reader:
        input_data=json.load(reader)["data"]#input_data是一个list,每一个值是一个dict
    def is_whitespace(c):
        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
            return True
        return False
    examples=[]
    for entry in input_data:
        #entry是一个dict,keys是['answers', 'filename', 'id', 'questions', 'source', 'story']
        paragraph_text=entry["story"]
        paragraph_id=entry["id"]
        doc_tokens = []
        char_to_word_offset = []
        prev_is_whitespace = True
        for c in paragraph_text:
            if is_whitespace(c):
                prev_is_whitespace = True
            else:
                if prev_is_whitespace:
                    doc_tokens.append(c)
                else:
                    doc_tokens[-1] += c
                prev_is_whitespace = False
            #char_to_word_offset记录的是整个paragraph_text中每一个字符对应的答案的下标,比如I want to.
            #那么char_to_word_offset就是[0,0,1,1,1,1,1,2,2,2],注意空格的位置下标算在它前一个单词的下标
            char_to_word_offset.append(len(doc_tokens) - 1)
        #doc_tokens就是整个paragraph_text利用空白字符分割后得到的单词列表,每一个值就是一个单词
        
        #接下来要做的是处理问答对,也就是每一个问题和对应的答案与paragraph_text构成一个example
        #由于对话型任务的特殊性,因此需要加上之前的问题与答案
        question_history_texts = []#用来记录所有的问题,其实严格来讲
        #还应该定义answer_history_texts用来记录所有的答案,然后拼接
        #entry["questions"]和entry["answers"]是两个长度一样的list
        for (question, ans) in zip(entry['questions'], entry['answers']):
            #question形如:{'input_text': 'When was the Vat formally opened?', 'turn_id': 1}
            #ans形如:{'input_text': 'It was formally established in 1475', 'span_end': 192, 'span_start': 151, 'span_text': 'Formally established in 1475', 
            #     'text': 'Formally established in 1475, although it', 'turn_id': 1, 'yes_no_ans': -1, 'yes_no_flag': 0}
            #text是答案的真实标签,注意这里所谓的真实标签是指:
            #1. 若该问题的答案是yes_no类型的,那么text是答案所在的句子
            #2. 若该问题的答案就是在原文中的一段跨度,那么text就是这段跨度
            #3. 若是该问题的答案不是原文的一段跨度,那么就根据文本片段单词匹配重合度,从原文中找出来一段与答案最相似的文本跨度作为真实标签

            #yes_no_flag=0代表这个问题不是yes_no答案类型的,所有yes_no_ans必定为-1
            #yes_no_flag=1时才表明这个问题是yes或者no,此时yes_no_ans=1表明答案是yes,yes_no_ans=0表明答案是no
            total_cnt += 1  
            cur_question_text = question["input_text"]
            question_history_texts.append(cur_question_text)#添加每一轮的问题
            question_id = question["turn_id"]
            ans_id = ans["turn_id"]
            start_position = None
            end_position =None
            yes_no_flag = None
            yes_no_ans = None
            orig_answer_text = None
            if (question_id != ans_id):
                print("question turns are not ordered!")
                print("mismatched question {}".format(cur_question_text))
            if is_training:
                orig_answer_text = ans["text"]
                answer_offset = ans["span_start"]
                answer_length = len(orig_answer_text)
                start_position = char_to_word_offset[answer_offset]
                if (answer_offset+answer_length >= len(char_to_word_offset)):
                    end_position = char_to_word_offset[-1]
                else:
                    end_position = char_to_word_offset[answer_offset + answer_length]
                #上面几行是用来寻找answer在document中的起始单词的位置和终止单词的位置作为交叉熵的标签
                actual_text = " ".join(doc_tokens[start_position:(end_position+1)])
                cleaned_answer_text = " ".join(whitespace_tokenize(orig_answer_text))
                yes_no_flag = int(ans["yes_no_flag"])
                yes_no_ans = int(ans["yes_no_ans"])
                if actual_text.find(cleaned_answer_text) == -1:
                    logger.warning("Could not find answer: '%s' vs. '%s'",
                                           actual_text, cleaned_answer_text)
                    continue     
            if (use_history):
                if (n_history == -1 or n_history > len(question_history_texts)):
                    question_texts = question_history_texts[:]
                else:
                    question_texts = question_history_texts[-1*n_history:]
            else:
                question_texts = question_history_texts[-1]
            #如果需要添加历史的问题,那么就设定n_history,然后添加到当前问题上,作为一个问题
            example = CoQAExample(
                paragraph_id=paragraph_id,
                turn_id=question_id,
                question_texts=question_texts,
                doc_tokens=doc_tokens,
                orig_answer_text = orig_answer_text,
                start_position=start_position,
                end_position=end_position,
                yes_no_flag=yes_no_flag,
                yes_no_ans=yes_no_ans)
            examples.append(example)
    #这里面每一个字段的含义如下:
    #paragraph_id指的是当前的document在数据集中的id,这个是数据集自带的,turn_id是指
    #当前的问题在当前的document中是第几个问题,question_texts是指当前问题的文本,也有可能是添加了历史问题的文本
    #orig_answer_text是指答案的文本,注意这里是分情况的,如果问题是yes_no类型,那么此时的答案是yes_no所在
    #的问题依据,如果问题的答案不是document的一段跨度,那么就从文章中选出来与答案单词重合度最高的片段作为答案
    logger.info("Total raw examples: {}".format(total_cnt))
    return examples     

在这里插入图片描述
在这里插入图片描述

在解读convert_examples_to_features之前我们先来看下面的例子

在这里插入图片描述

由于BERT采用的wordpiece子词分割的方式,所以Vat,formally,opened都被分开成多个单词。整个句子When was the Vat formally opend?一共有六个单词,被分成了11个单词
  • tok_to_orig_map这个dict记录是子词位置到全词位置的映射关系,比如[0,1,2,3,3,4,4,4,5,5,5]表明,子词序列的第一个、第二个、第三个单词和全词序列的前三个单词一一对应,而子词序列的第四个和第五个单词对应的是全词序列的第四个单词,类似的,子词序列的第6、7、8个单词对应的是全词序列的第五个单词。
  • tok_to_orig_index其实就是与tok_to_orig_map对应的列表,通过列表的下标(列表的下标代表的是子词序列中某个单词的位置)就可以知道对应的全词的位置
  • orig_to_tok_index其实正好反过来,根据列表的下标(列表的下标对应的是全词序列中每一个单词的位置)就可以知道该全词所对应的子词序列中单词的位置。
  • 显然all_doc_tokens是所有的子词序列
def convert_examples_to_features(examples, tokenizer, max_query_length,
                                 is_training=True, append_history=False):
    features = []
    for (example_index, example) in enumerate(examples):
        all_query_tokens = [tokenizer.tokenize(question_text) for question_text in [example.question_texts]]
        #注意,源码这里是有问题的,example.question_texts是一个字符串
        #所以for question_text in example.question_texts显示是在遍历每一个字符,所以这里有问题
        # same as basic Bert
        #all_query_tokens是经过wordpiece分词后的子词序列
        if append_history:
            all_query_tokens = all_query_tokens[::-1]
        flat_all_query_tokens = []
        for query_tokens in all_query_tokens:
            flat_all_query_tokens += query_tokens
        if append_history:
            query_tokens = flat_all_query_tokens[:max_query_length]
        else:
            query_tokens = flat_all_query_tokens[-1*max_query_length:]
        #这几行可以不用考虑,因为之前我们没有添加历史问题
        #query_tokens就是将一个问题文本通过tokenizer进行wordpiece分词
        
        tok_to_orig_index = []
        tok_to_orig_map = {}
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_map[len(all_doc_tokens)] = i
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        #经过前面的分词我们已经知道这几个变量的含义了
        #orig_to_tok_index中的每一个位置的下标代表的单词在全词序列中的位置,对应的值代表的是该词在子词序列中的位置
        #因为输入给BERT的是子词,所以我们的目标是根据全词序列中答案位置找到对应的子词序列中的位置
        tok_start_position = None
        tok_end_position = None
        if is_training:
            tok_start_position = orig_to_tok_index[example.start_position]#我们获得了子词序列中答案的起始位置
            if example.end_position < len(example.doc_tokens) - 1:
                # tok_end_position is the last sub token of orig end_position
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
                #注意example.end_position+1的目的是因为python中的列表是不取最后的索引的,比如[:1],那么只会取第一个单词
                #-1是因为列表下标是从0开始的
            else:
                tok_end_position = len(all_doc_tokens) - 1#如果答案的终止位置超过文章子词序列的长度,那么就把最后一个
                #单词作为终止位置
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
                example.orig_answer_text)#这个函数也不必了解,主要是有些特殊字符需要处理
            #那么一旦处理后答案位置就变了,所以这行函数我们不必详细了解
        #现在已经将question和document都进行了wordpiece分词,而且也已经得到了答案在子词序列下的起始位置和终止位置

        features.append(
            ExampleFeature(
                example_index=example_index,
                query_tokens=query_tokens,
                doc_tokens=all_doc_tokens,
                tok_to_orig_map=tok_to_orig_map,
                start_position=tok_start_position,
                end_position=tok_end_position,
                yes_no_flag=example.yes_no_flag,
                yes_no_ans=example.yes_no_ans))
    return features

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
可以看出目前特征feature与样本example没有问题,子词序列和全词序列中的答案位置对应关系也没有问题

有了examples和features就可以用来训练模型了

在这里插入图片描述
上面这张图片展示了我们构造好了各个张量准备送给模型。其中:

  • batch_query_tokens 是batch_size个问题文本
  • batch_doc_tokens 是batch_size个document他们都是经过了wordpiece分词的
  • 其他的不必解释

这里需要注意与BERT不同的地方是,BERT由于是固定分割,并且分割后的segment是独立的,所以假如一篇document被分割成了4个segment,那么就相当于一篇document变成了4个样本,独立的送入BERT当中,而这篇论文的document即使分割后也是一个整体,要保证连续性,所以此时我们的document还没有分割,分割的工作也是模型的一部分,而baseline的方法中,分割的操作是在数据预处理部分完成的,不参与到模型中

接下来把batch_query_tokens, batch_doc_tokens, batch_start_positions, batch_end_positions, batch_max_doc_length送入到模型中

定义一个变量cur_global_pointers用来指示当前的segment相较于上一次的segment的移动步数,比如cur_global_pointers=-16,那么就代表当前的segment应该在上一次的segment的开始位置的基础上在document上向左滑动16的单词.

对于第一次分割,显然cur_global_pointers都是0,因为只能从第一个单词开始.

def gen_model_features(cur_global_pointers, batch_query_tokens, batch_doc_tokens, \
                       batch_start_positions, batch_end_positions, batch_max_doc_length, \
                       max_seq_length, tokenizer, is_train):
    '''
    cur_global_pointers用来指示当前的segment应该在上一次的segment的基础上如何移动
    函数的目的就是根据cur_global_pointers,在document上重新分割,
    重新分割后的start_position和end_position要发生变化
    '''
    chunk_doc_offsets = []
    chunk_doc_tokens = []
    chunk_start_positions = []
    chunk_end_positions = []
    chunk_stop_flags = []
    for index in range(len(cur_global_pointers)):
        # span: [doc_start, doc_span)
        doc_start = max(0, cur_global_pointers[index])#有可能出现第一次预测后模型希望向左移动分割
        doc_end = min(doc_start + batch_max_doc_length[index], len(batch_doc_tokens[index]))
        if (doc_start >= len(batch_doc_tokens[index])):
            doc_end = len(batch_doc_tokens[index])
            doc_start = max(0, doc_end - batch_max_doc_length[index])
        chunk_doc_tokens.append(batch_doc_tokens[index][doc_start:doc_end])
        chunk_doc_offsets.append(doc_start)
        if is_train:
            one_doc_len = doc_end - doc_start
            one_start_position = batch_start_positions[index] - doc_start#修改答案的位置
            one_end_position = batch_end_positions[index] - doc_start
            # 上面的几行代码和baseline的做法一致
            # 注意下面的代码,在BERT中,如果分割后的segment不包含有答案,那么是不作为样本训练模型的
            #但是在这篇论文中整个document是一个整体,要保留所有的segment
            if (one_start_position < 0 or one_start_position >= one_doc_len or \
                one_end_position < 0 or one_end_position >= one_doc_len):
                chunk_stop_flags.append(0)
                chunk_start_positions.append(max_seq_length)
                chunk_end_positions.append(max_seq_length)
                #对于那些不包含answer的segment,我们标记该segment不包含答案
            else:
                chunk_stop_flags.append(1)
                chunk_start_positions.append(one_start_position)
                chunk_end_positions.append(one_end_position)

    # 经过上面的代码我们已经分割得到了一个segment,下面就是把question和这个segment连接构成[CLS]question[SEP]segment[SEP]
    chunk_input_ids = []
    chunk_segment_ids = []
    chunk_input_mask = []
    # position in input_ids to position in batch_doc_tokens
    id_to_tok_maps = []
    for index in range(len(cur_global_pointers)):
        one_id_to_tok_map = {}
        one_query_tokens = batch_query_tokens[index]
        one_doc_tokens = chunk_doc_tokens[index]
        one_doc_offset = chunk_doc_offsets[index]
        one_tokens = []
        one_segment_ids = []
        one_tokens.append("[CLS]")
        one_segment_ids.append(0)
        # add query tokens
        for token in one_query_tokens:
            one_tokens.append(token)
            one_segment_ids.append(0)
        one_tokens.append("[SEP]")
        one_segment_ids.append(0)
        # add doc tokens
        for (i, token) in enumerate(one_doc_tokens):
            one_id_to_tok_map[len(one_tokens)] = one_doc_offset + i
            one_tokens.append(token)
            one_segment_ids.append(1)
        one_tokens.append("[SEP]")
        one_segment_ids.append(1)
        id_to_tok_maps.append(one_id_to_tok_map)
        #这些代码和BERT是一模一样的
        # gen features
        one_input_ids = tokenizer.convert_tokens_to_ids(one_tokens)
        one_input_mask = [1] * len(one_input_ids)
        while len(one_input_ids) < max_seq_length:
            one_input_ids.append(0)
            one_input_mask.append(0)
            one_segment_ids.append(0)
        assert len(one_input_ids) == max_seq_length
        assert len(one_input_mask) == max_seq_length
        assert len(one_segment_ids) == max_seq_length
        chunk_input_ids.append(one_input_ids[:])
        chunk_input_mask.append(one_input_mask[:])
        chunk_segment_ids.append(one_segment_ids[:])
        if is_train:
            # adjust start_positions and end_positions due to doc offsets caused by query and CLS/SEP tokens in the input feature
            chunk_start_positions[index] += len(one_query_tokens) + 2
            chunk_end_positions[index] += len(one_query_tokens) + 2

    return chunk_input_ids, chunk_input_mask, chunk_segment_ids, id_to_tok_maps, \
           chunk_start_positions, chunk_end_positions, chunk_stop_flags

我们来看看gen_model_features获得了什么

我们重新写下代码

train_indices=torch.arange(len(features),dtype=torch.long)#生成从0到所有features_nums的张量
train_sampler=SequentialSampler(train_indices)#注意这里我改成了顺序取数据
train_dataloader=DataLoader(train_indices,sampler=train_sampler,batch_size=6,drop_last=True)
for step, batch_indices in enumerate(tqdm(train_dataloader, desc="Iteration")):
    batch_features = [features[ind] for ind in batch_indices]
    batch_query_tokens = [f.query_tokens for f in batch_features]
    batch_doc_tokens = [f.doc_tokens for f in batch_features]
    batch_start_positions = [f.start_position for f in batch_features]
    batch_end_positions = [f.end_position for f in batch_features]
    batch_yes_no_flags = [f.yes_no_flag for f in batch_features]
    batch_yes_no_answers = [f.yes_no_ans for f in batch_features]

#由于是顺序取数据,所以examples[i]和batch_doc_tokens[i]是对应的
i=5
print(examples[i].orig_answer_text)
print(examples[i].start_position,examples[i].end_position)
print(examples[i].doc_tokens[examples[i].start_position:examples[i].end_position+1])

print(features[i].start_position,features[i].end_position)
print(features[i].tok_to_orig_map[features[i].start_position],
     features[i].tok_to_orig_map[features[i].end_position])
print(features[i].doc_tokens[features[i].start_position:features[i].end_position+1])

#上面的几行代码用来显示我们的examples和features是否是对应的
max_seq_length=256#最大长度设置为256
cur_global_pointers=[0]*6
batch_max_doc_length=[max_seq_length-3-len(query_tokens) for query_tokens in batch_query_tokens]
chunk_input_ids, chunk_input_mask, chunk_segment_ids, id_to_tok_maps, chunk_start_positions, chunk_end_positions, chunk_stop_flags=gen_model_features(cur_global_pointers, batch_query_tokens, batch_doc_tokens, batch_start_positions, batch_end_positions, batch_max_doc_length, max_seq_length, tokenizer, is_train=True)
print(chunk_start_positions,chunk_end_positions)
i=3
print(chunk_start_positions[i],chunk_end_positions[i])
print(id_to_tok_maps[i][chunk_start_positions[i]],
      id_to_tok_maps[i][chunk_end_positions[i]])
orig_answer_start_position=id_to_tok_maps[i][chunk_start_positions[i]]
orig_answer_end_position=id_to_tok_maps[i][chunk_end_positions[i]]
print(batch_doc_tokens[i][orig_answer_start_position:orig_answer_end_position+1])
print(examples[i].orig_answer_text)

运行结果
在这里插入图片描述
看来切分后的chunk_start_positions和chunk_end_positions是没有错的。
chunk_input_ids, chunk_input_mask, chunk_segment_ids以及chunk_start_positions, chunk_end_positions就是模型需要的输入和标签,额外多了chunk_stop_flags用来标记segment是否包含有answer

device=torch.device("cpu")
chunk_input_ids = torch.tensor(chunk_input_ids, dtype=torch.long, device=device)
chunk_input_mask = torch.tensor(chunk_input_mask, dtype=torch.long, device=device)
chunk_segment_ids = torch.tensor(chunk_segment_ids, dtype=torch.long, device=device)
chunk_start_positions = torch.tensor(chunk_start_positions, dtype=torch.long, device=device)
chunk_end_positions = torch.tensor(chunk_end_positions, dtype=torch.long, device=device)
chunk_yes_no_flags = torch.tensor(batch_yes_no_flags, dtype=torch.long, device=device)
chunk_yes_no_answers = torch.tensor(batch_yes_no_answers, dtype=torch.long, device=device)
chunk_stop_flags = torch.tensor(chunk_stop_flags, dtype=torch.long, device=device)

上面的代码是将list转为了torch tensor,接下来终于可以输入到模型中了

这里一共有8个tensor,前5个是任何的抽取式阅读理解任务必须的,yes_no_flags用来表明问题是否是yes_no类型的问题,如果是,那么yes_no_flag=1,对应的yes_no_answers则代表该答案具体是yes还是no,如果yes_no_flag=0,那么yes_no_answer一定为-1。stop_flags用来标记该segment是否包含answer

##########################模型部分#########################

第一部分LSTM的循环机制:

在这里插入图片描述

class recurLSTMNetwork(nn.Module):
    """
    对应着论文中提出的循环机制,该类就是一个LSTM
    """
    def __init__(self, input_size, hidden_size):
        super(recurLSTMNetwork, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lstm = nn.LSTMCell(self.input_size, self.hidden_size)
        print("LSTM recurrence...")

    def forward(self, x_t, lstm_prev_states):
        """
        输入有两个,x_t对应论文中的v_t,也就是当前时间步的CLS的hidden_state
        lstm_prev_states就是上一个时间步的LSTM的hidden_state
        """
        bsz = x_t.size(0)
        if lstm_prev_states is None:
            lstm_prev_states = (torch.zeros([bsz, self.hidden_size], device=x_t.device), \
                                torch.zeros([bsz, self.hidden_size], device=x_t.device))
        
        hidden_states, cell_states = self.lstm(x_t, lstm_prev_states)
        return (hidden_states, cell_states)#这个hidden_states会被用来作为下一个segment的lstm_prev_states

第二部分Chunking Scorer

在这里插入图片描述
其中的 v ~ c \tilde{v}_c v~c就是当前的LSTM的输出

class stopNetwork(nn.Module):
    """
    chunk_states就是当前时间步的LSTM的输出,用来判断当前的segment是否包含answer
    """
    def __init__(self, input_size):
        super(stopNetwork, self).__init__()
        self.fc = nn.Linear(input_size, 2)

    def forward(self, chunk_states):
        stop_logits = self.fc(chunk_states)
        return stop_logits

第三部分,Chunking Policy,也就是策略网络,用来决定下一次应该如何分割

在这里插入图片描述

class moveStrideNetwork(nn.Module):
    """
    传进来chunk_states是LSTM的hidden_state,所以chunk_states.shape==(batch_size,lstm_hidden_dim)
    按照论文的公式8,我们要将chunk_states通过输出维度是action_space的全连接层,然后softmax就得到了在
    状态空间上的概率分布,然后采样就会输出当前segment下模型采取的action,这个action就是{-16,16,32,64,128}中的某一个
    """
    def __init__(self, input_size, num_action_classes):
        super(moveStrideNetwork, self).__init__()
        self.fc = nn.Linear(input_size, num_action_classes)

    def forward(self, chunk_states, scheme="sample"):
        # stride_probs: (bsz, num_stride_choices)
        outputs = self.fc(chunk_states)
        stride_probs = F.softmax(outputs, dim=1)
        stride_log_probs = F.log_softmax(outputs, dim=1)
        if scheme == "sample":
            policy = Categorical(stride_probs.detach())
            # sampled_stride_inds: (bsz,)
            sampled_stride_inds = policy.sample()
        elif scheme == "greedy":
            # sampled_stride_inds: (bsz,)
            sampled_stride_inds = torch.argmax(stride_probs.detach(), dim=1)
        # sampled_stride_log_probs: (bsz, )
        sampled_stride_log_probs = stride_log_probs.gather(1, sampled_stride_inds.unsqueeze(1)).squeeze(1)
        #sampled_stride_inds就是长度为batch_size的tensor,每一个值代表每一个样本在各自的segment下做出的行为,这个值
        #也就是cur_global_pointers
        #sampled_stride_log_probs是做出该行为的log概率
        return sampled_stride_inds, sampled_stride_log_probs

模型整体流程:

class RCMBert(BertPreTrainedModel):
    def __init__(self, config, action_num, recur_type="gated", allow_yes_no=False):
        super(RCMBert, self).__init__(config)
        self.bert = BertModel(config)
        self.recur_type = recur_type
        self.allow_yes_no = allow_yes_no
        if recur_type == "gated":
            self.recur_network = recurGatedNetwork(config.hidden_size, config.hidden_size)
        elif recur_type == "lstm":
            self.recur_network = recurLSTMNetwork(config.hidden_size, config.hidden_size)
        else:
            print("Invalid recur_type: {}".format(recur_type))
            sys.exit(0)
        self.action_num = action_num
        self.stop_network = stopNetwork(config.hidden_size)
        self.move_stride_network = moveStrideNetwork(config.hidden_size, self.action_num)
        
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if self.allow_yes_no:
            self.yes_no_flag_outputs = nn.Linear(config.hidden_size, 2)
            self.yes_no_ans_outputs = nn.Linear(config.hidden_size, 2)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        self.init_weights()


    def forward(self, input_ids, token_type_ids, attention_mask,
                prev_hidden_states, stop_flags=None,
                start_positions=None, end_positions=None,
                yes_no_flags=None, yes_no_answers=None):
        """
        模型的整体流程是:
        1. 将input_ids通过BERT获得了sequence_output,形状为(batch_size,max_seq_length,768)
        2. 将sequence_output所有序列的第一个值,也就是所有的CLS的表示拿出来得到sent_output 形状为(batch_size,768)
        3. 将sent_output也就是CLS表示,通过LSTM,注意这里还要有上一次的segment的LSTM的输出,得到recur_sent_output
        4. 得到了LSTM的输出之后,
            4.1: 将LSTM的输出通过stopNetwork网络,得到论文中的q_c用来判断当前的segment包含answer的概率
            4.2: 将LSTM的输出通过StridePolicyNetwork,得到行为的输出,以及采取这个行为的概率
        5. 将sequence_output用来做答案的预测,得到开始位置和结束位置的概率
        
        6. 获得了上面的值之后,就可以计算loss了
        
        ## 这里需要注意的是对于那些有yes_no问题的数据集,需要把yes_no的问题单独处理,自然也要有单独的loss
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                                       token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        # add dropout
        sequence_output = self.dropout(sequence_output)
        # sent_output: (batch_size, hidden_size)
        sent_output = sequence_output.narrow(1, 0, 1)
        sent_output = sent_output.squeeze(1)
        
        # combine hidden_states for stop and moving prediction
        if self.recur_type == "gated":
            cur_hidden_states = sent_output if prev_hidden_states is None \
                                else self.recur_network(sent_output, prev_hidden_states)
            recur_sent_output = cur_hidden_states
        elif self.recur_type == "lstm":
            cur_hidden_states = self.recur_network(sent_output, prev_hidden_states)
            recur_sent_output = cur_hidden_states[0]

        # stop logits: (bsz, 2)
        stop_logits = self.stop_network(recur_sent_output)
        
        # answer prediction in the current span
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        # yes-no question
        if self.allow_yes_no:
        #如果允许yes_no,那么就把CLS的表示通过单独用来处理yes_no的网络层
            yes_no_flag_logits = self.yes_no_flag_outputs(sent_output)
            yes_no_ans_logits = self.yes_no_ans_outputs(sent_output)
            yes_no_flag_logits = yes_no_flag_logits.squeeze(-1)
            yes_no_ans_logits = yes_no_ans_logits.squeeze(-1)
        
        # get loss for stop & answer prediction in the chunk level
        if start_positions is not None and end_positions is not None and \
           stop_flags is not None:
            # stride 
            sampled_stride_inds, sampled_stride_log_probs = self.move_stride_network(recur_sent_output, scheme="sample")
            
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            if len(stop_flags.size()) > 1:
                stop_flags = stop_flags.squeeze(-1)

            # stop loss
            stop_loss_fct = CrossEntropyLoss(reduction='mean')
            stop_loss = stop_loss_fct(stop_logits, stop_flags)

            # answer loss
            if self.allow_yes_no and yes_no_flags is not None and \
               yes_no_answers is not None:
                # ground truth
                if len(yes_no_flags.size()) > 1:
                    yes_no_flags = yes_no_flags.squeeze(-1)
                if len(yes_no_answers.size()) > 1:
                    yes_no_answers = yes_no_answers.squeeze(-1)
                    
                # for all samples: classify yes-no / wh- question
                # this is purely query-dependent, and not influenced by stop_flags
                flag_loss_fct = CrossEntropyLoss(reduction='mean')
                yes_no_flag_loss = flag_loss_fct(yes_no_flag_logits, yes_no_flags)
                answer_loss = 0.25 * yes_no_flag_loss
                #额外预测每一个example的问题是否是yes_no的问题
                
                # estimate loss only when the current chunk contains the answer
                yes_no_indices = (stop_flags + yes_no_flags == 2).nonzero().view(-1)
                wh_indices = (stop_flags - yes_no_flags == 1).nonzero().view(-1)
                #需要注意的是yes_no问题需要单独计算他们的loss
                # for samples with wh- questions
                selected_start_positions = start_positions.index_select(0, wh_indices)
                selected_end_positions = end_positions.index_select(0, wh_indices)
                selected_start_logits = start_logits.index_select(0, wh_indices)
                selected_end_logits = end_logits.index_select(0, wh_indices)
                # sometimes the start/end positions are outside our model inputs, we ignore these terms
                # here index is word index instead of sample index
                ignored_index = selected_start_logits.size(1)
                selected_start_positions.clamp_(0, ignored_index)
                selected_end_positions.clamp_(0, ignored_index)
                loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
                if (selected_start_positions.size(0) > 0):
                    start_loss = loss_fct(selected_start_logits, selected_start_positions)
                    end_loss = loss_fct(selected_end_logits, selected_end_positions)
                    answer_loss += 0.25 * start_loss + 0.25 * end_loss

                # for samples with yes-no questions
                # CrossEntropyLoss: input: (seq_len, C), target: (seq_len, )
                selected_yes_no_ans_logits = yes_no_ans_logits.index_select(0, yes_no_indices)
                selected_yes_no_answers = yes_no_answers.index_select(0, yes_no_indices)
                ans_loss_fct = CrossEntropyLoss(reduction='mean')
                if (selected_yes_no_ans_logits.size(0) > 0):
                    yes_no_ans_loss = ans_loss_fct(selected_yes_no_ans_logits, \
                                                   selected_yes_no_answers)
                    answer_loss += 0.25 * yes_no_ans_loss
                return stop_logits, sampled_stride_inds, sampled_stride_log_probs, \
                       start_logits, end_logits, yes_no_flag_logits, yes_no_ans_logits, \
                       cur_hidden_states, stop_loss, answer_loss
                    
            else:
                # only answer span selection
                ignored_index = start_logits.size(1)
                start_positions.clamp_(0, ignored_index)
                end_positions.clamp_(0, ignored_index)
                loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
                start_loss = loss_fct(start_logits, start_positions)
                end_loss = loss_fct(end_logits, end_positions)
                answer_loss = 0.5 * start_loss + 0.5 * end_loss
                
                return stop_logits, sampled_stride_inds, sampled_stride_log_probs, \
                       start_logits, end_logits, cur_hidden_states, stop_loss, answer_loss

        else:
            # stride 
            sampled_stride_inds, sampled_stride_log_probs = self.move_stride_network(recur_sent_output, scheme="greedy")
            if self.allow_yes_no:
                return stop_logits, sampled_stride_inds, sampled_stride_log_probs, \
                       start_logits, end_logits, yes_no_flag_logits, yes_no_ans_logits, \
                       cur_hidden_states
            else:
                return stop_logits, sampled_stride_inds, sampled_stride_log_probs, \
                       start_logits, end_logits, cur_hidden_states
这里需要注意的是answer loss,对于有yes_no问题的阅读理解任务,此时的answer_loss=0.25*yes_no_flag_loss+0.25*yes_no_ans_loss+0.25*start_loss+0.25*end_loss

对于不包含yes_no问题的任务,answer_loss就是指开始位置和结束位置的交叉熵损失。

在训练阶段,模型返回的张量如下:

stop_logits, sampled_stride_inds, sampled_stride_log_probs,  start_logits, end_logits, yes_no_flag_logits, yes_no_ans_logits,  cur_hidden_states, stop_loss, answer_loss

其中:

  1. stop_logits是指LSTM的输出通过一个输出维度是2的全连接层,注意此时还没有概率归一化
  2. sampled_stride_inds是指经过policy network后模型做出的行为,也就是在 p a c t ( a ∣ s ) p^{act}(a|s) pact(as)上进行随机采样后得到的输出,比如2,那么对应在行为空间{-16,16,32,64,128}上的值就是32,意味着下一次的分割模型要在原来的基础上向右移动32个单词(这个值也就是cur_global_pointers的值,在get_model_features的具体代码为doc_start = max(0, cur_global_pointers[index]))
  3. sampled_stride_log_probs就是模型做出这个行为的log概率值,也就是 log ⁡ p a c t ( a ∣ s ) \log p^{act}(a|s) logpact(as)
  4. start_logits和end_logits不必细说
  5. yes_no_flag_logits是指模型预测该问题是否是yes_no问题的分数
  6. yes_no_ans_logits是指模型预测该问题的答案是yes或者no的分数
  7. cur_hidden_states就是当前的LSTM的输出,用来作为下一个segment的LSTM的输入,达到循环机制的目的
  8. 模型的loss由三部分组成,stop_loss是指 L c s L_{cs} Lcs,answer_loss是指 L a n s L_{ans} Lans

接下来就是根据sampled_stride_log_probs以及stop_logits还有start_logits和end_logits计算 ∇ L c p \nabla L_{cp} Lcp

其中:

  • sampled_stride_log_probs对应 log ⁡ p a c t ( a ∣ s ) \log p^{act}(a|s) logpact(as)
  • stop_logits经过softmax后对应的是 q c q_c qc
  • start_logits经过softmax后对应的是 p c , i s t a r t p_{c,i}^{start} pc,istart
  • end_logits经过softmax后对应的是 p c , j e n d p_{c,j}^{end} pc,jend

接下来根据模型的输出构建loss

stride_log_probs = []
stop_rewards = []
stop_probs = []
cur_global_pointers = [0] * batch_size # global position of current pointer at the document
batch_max_doc_length = [args.max_seq_length-3-len(query_tokens) for query_tokens in batch_query_tokens]
for _ in range(max_read_times):
    #需要注意的是目前是在max_read_times次数下循环进行的
    chunk_stop_logits, chunk_stride_inds, chunk_stride_log_probs, \
                       chunk_start_logits, chunk_end_logits, \
                       chunk_yes_no_flag_logits, chunk_yes_no_ans_logits, \
                       prev_hidden_states, chunk_stop_loss, chunk_answer_loss = \
                       model(chunk_input_ids, chunk_segment_ids, chunk_input_mask,
                             prev_hidden_states, chunk_stop_flags,
                             chunk_start_positions, chunk_end_positions,
                             chunk_yes_no_flags, chunk_yes_no_answers)
    #我们已经获得了模型的输出,各个变量的含义也已经解释了
    chunk_stop_logits = chunk_stop_logits.detach()
    chunk_stop_probs = F.softmax(chunk_stop_logits, dim=1)
    chunk_stop_probs = chunk_stop_probs[:, 1]#chunk_stop_probs有两列,第二列代表是预测当前segment包含答案的概率
    stop_probs.append(chunk_stop_probs.tolist())#chunk_stop_probs是当前这一次下的q_c
    #由于奖赏是要max_read_times结束后才能计算出来,所以需要记住每一次的q_c和r_c
    #q_c对应chunk_stop_probs
    chunk_stop_logits = chunk_stop_logits.tolist()
    #模型的loss由三部分组成,其中的stop_Loss和answer_loss在每一次下都会计算出来,所以累加max_read_times然后取平均即可
    if stop_loss is None:
        stop_loss = chunk_stop_loss
    else:
        stop_loss += chunk_stop_loss

    if answer_loss is None:
        answer_loss = chunk_answer_loss
    else:
        answer_loss += chunk_answer_loss
    #但是rl_loss,也就是强化学习的那部分loss必须要等到max_read_times结束后才能计算出来
    # take movement action
    chunk_strides = [stride_action_space[stride_ind] for stride_ind in chunk_stride_inds.tolist()]
    cur_global_pointers = [cur_global_pointers[ind] + chunk_strides[ind] for ind in range(len(cur_global_pointers))]
    # put pointer to 0 or the last doc token
    cur_global_pointers = [min(max(0, cur_global_pointers[ind]), len(batch_doc_tokens[ind])-1) \
                           for ind in range(len(cur_global_pointers))]
    #chunk_stride_inds已经解释过了,就是上一次segment模型做出的行为,显然chunk_strides形如[32,128,-16,16,等等]
    #cur_global_pointers就是根据上一次segment的起始位置加上模型做出的行为,得到了这一次的segment
    #比如上一次的cur_pointers=[0,16,128,32,等等],那么这一次的cur_global_pointers=[32,144,112,48,等等]
    #也就是说这一次的batch_size个样本中,第一个样本从它的document的第32个单词开始切分,第二个样本从它的document的第144个单词开始切分,等等
    
    #接下来是计算r_c:
    chunk_start_probs = F.softmax(chunk_start_logits.detach(), dim=1).tolist()
    chunk_end_probs = F.softmax(chunk_end_logits.detach(), dim=1).tolist()
    #chunk_start_probs.shape==(batch_size,max_seq_length),他表示的是max_seq_length的每一个单词作为答案起始位置的概率
    #chunk_end_probs同理
    chunk_yes_no_flag_probs = F.softmax(chunk_yes_no_flag_logits.detach(), dim=1).tolist()
    chunk_yes_no_ans_probs = F.softmax(chunk_yes_no_ans_logits.detach(), dim=1).tolist()
    # 上面两行代码主要是处理yes_no问题的,如果数据集没有yes_no的问题则不用考虑
    #接下来就是根据start_probs和end_probs以及stop_flags计算r_c,对应的是论文中的公式12
    chunk_stop_rewards = reward_estimation_for_stop(chunk_start_probs, chunk_end_probs,
                                                    chunk_start_positions.tolist(), chunk_end_positions.tolist(),
                                                    chunk_yes_no_flag_probs, chunk_yes_no_ans_probs,
                                                    batch_yes_no_flags, batch_yes_no_answers, chunk_stop_flags.tolist())
    stop_rewards.append(chunk_stop_rewards)#这个stop_rewards就是论文中的r_c,它的含义是模型从当前的segment中正确提取出答案的概率
    #前面已经说过,强化学习部分的loss只有当整个max_read_times结束后才能计算出来,因此stop_rewards存储的是max_read_times下的r_c
    # save history (exclude the prob of the last read since the last action is not evaluated)
    if (t < args.max_read_times - 1):
        stride_log_probs.append(chunk_stride_log_probs)
        #现在我们已经存储了每一次下的q_c和r_c,也就是说当max_read_times结束后就可以计算R(s,a)了,然而我们还没有存储每一次的log p(a|s)
        #所以stride_log_probs存储的就是每一次的log p(a|s)
        #需要注意的是,最后一次采取的行为是没有奖励的,所以我们不需要考虑最后一次的log p(a|s)
#到了这里我们就可以计算强化学习部分的loss了,我们自己构造下数据
#假设batch_size=2,max_read_times=5,
#stop_probs代表的是这2个document在这5次的切分当中,每一次模型预测segment包含answer的概率,
stop_probs=[[0.15,0.5],
            [0.2,0.7],
            [0.25,0.35],
           [0.40,0.2],
           [0.15,0.30]]#每一个值对应的是q_c
#stop_rewards代表的是这2个document在这5次的切分当中,每一次模型从segment中提取出answer的概率
stop_rewards=[[0.08,0.15],
              [0.0,0.36],
              [0.5,0.0],
             [0.15,0.25],
             [0.0,0.4]]#每一个值对应的是r_c,0.0代表这个segment不包含answer
stop_probs=np.transpose(stop_probs)
stop_rewards=np.transpose(stop_rewards)
print(stop_probs)
print(stop_rewards)

重点来了:::::::::

第一个样本的stop_probs=[0.15 0.2 0.25 0.4 0.15]

第一个样本的stop_rewards=[0.08 0. 0.5 0.15 0. ]

下面我们按照公式 R ( s , a ) = q c r c + ( 1 − q c ) R ( s ′ , a ′ ) R(s,a)=q_cr_c+(1-q_c)R(s',a') R(s,a)=qcrc+(1qc)R(s,a)来手动计算第一个样本所应该获得的奖励

首先我们需要的是从后向前计算,max_read_times=5

  • 第五次由于没有下一次了,所以第五次的奖励设置为第五次的 r c r_c rc,也就是0.
  • 第四次的奖励为0.4*0.15+(1-0.4)*0.=0.06
  • 第三次的奖励为0.25*0.5+(1-0.25)*0.06=0.17
  • 第二次的奖励为0.2*0. +(1-0.2)*0.17=0.136
  • 第一次的奖励为0.15*0.08+(1-0.15)*0.136=0.1276

所以对于第一个样本从第一次到倒数第二次的奖励应该为[0.1276,0.136,0.17,0.06]

但是按照论文的github代码运行出来的结果:
在这里插入图片描述
我觉得我的理解没有问题,感觉好像是论文的代码错了

所以我认为rl_reward.py中的代码应该是下面这样

q_vals = []
# calc from the end to the beginning time
next_q_vals = stop_rewards[:,-1] #np.zeros(len(stop_rewards))
for t in reversed(range(0, stop_rewards.shape[1]-1)):
    t_rewards = stop_rewards[:, t]
    t_probs = stop_probs[:, t]

    cur_q_vals = np.multiply(t_rewards, t_probs) + np.multiply(next_q_vals, 1-t_probs)
    q_vals.append(list(cur_q_vals)[:])
    next_q_vals = cur_q_vals
# q_vals: (bsz, max_read_times-1)
q_vals=list(reversed(q_vals))
q_vals = np.transpose(q_vals)

运行结果:
在这里插入图片描述

假如我的理解是对的,那么我们已经获得了每一个时刻的奖励 R ( s , a ) R(s,a) R(s,a),而且之前也已经把各个时刻的stride_log_probs记录下来,

那么现在 log ⁡ p a c t ( a ∣ s ) \log p^{act}(a|s) logpact(as) R ( s , a ) R(s,a) R(s,a)都已经得到了,

接下来自然就是:
在这里插入图片描述

reinforce_loss=torch.mean(torch.sum(-q_vals*stride_log_probs,dim=1))

最终的loss形式为:

loss = (stop_loss + answer_loss) / args.max_read_times + reinforce_loss

在这里插入图片描述

这篇论文的GitHub代码是有问题的,比如你会发现q_vals变量的值一直是0,也就是模型的奖励始终是0,这就导致强化学习部分的损失函数是没有起到任何作用.

我们看run_RCM_coqa.py中的一段代码

在这里插入图片描述
然而我们和reward_estimation_for_stop函数对比一下:
在这里插入图片描述
你就会发现变量传入的顺序是有问题的。
正确的传参顺序如下:
在这里插入图片描述

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值