【Bert】(十三)简易问答系统--源码解析(测试)

该博客介绍了BERT模型在测试阶段如何处理答案合理性问题,包括处理句子长度不足、起始和终止位置超出范围、答案位于问题区域内、滑动窗口导致的单词归属问题以及起始位置在终止位置之后等异常情况。BERT通过选取多个候选起始和终止位置,排除不合理组合,并结合得分排序选择最终答案。此外,还详细解释了get_final_text函数,用于从预处理的文本中恢复原始答案。
摘要由CSDN通过智能技术生成

上一篇博客介绍的损失部分就涉及训练的过程。

本篇介绍一下测试。按照上一篇博客介绍损失时,start_logits选取最大的概率值作为起始位置与真实起始位置比较,end_logits选取最大的概率值作为终止位置与真实终止位置比较。那么直观观念上测试只需要分别选取start_logits和end_logits的最大值,就能得到起始位置和终止位置。

但是会碰到如下几个问题

(1)很多时候句子达不到设定的seq_length的长度,假如设定输入模型的整个句子的向量长度为384,但是实际问题+段落的长度才181,而预测的起始位置为230,预测得到的起始位置大于句子的实际长度,这显然不合理。

(2)预测得到的终止位置大于句子的实际长度,这显然不合理。

(3)由于输入模型的句子向量中,第一部分为问题,第二部分为段落。预测得到的起始位置落在第一部分问题的区域间断,这显然不合理。

(4)由于输入模型的句子向量中,第一部分为问题,第二部分为段落。预测得到的终止位置落在第一部分问题的区域间断,这显然不合理。

(5) 由于在【Bert】(十)简易问答系统--数据解析_mjiansun的专栏-CSDN博客中介绍过滑动窗切割段落的情况,是的部分重叠单词有了归属,也就是部分单词属于切割后的第一个句子,有些单词属于切割后的第二个句子,所以预测的起始位置的单词如果不属于本句话,那么也应该判定为不合理。

(6)如果start_logits最大概率所在位置A,end_logits最大概率所在位置B,A的位置在B的后面,这就不符合逻辑,起始位置怎么能在终止位置后面

(7)一般的答案都会有一个长度,一般都不会太长,如果出现一个答案特别长,这显然也不合理

针对上述问题,bert是这样解决的

选取一定数量的候选起始位置和终止位置,让每一个起始位置和终止位置进行排列组合,然后跳过出现上问题的组合,保留满足条件的组合即可。

    for (feature_index, feature) in enumerate(features):
      result = unique_id_to_result[feature.unique_id]
      start_indexes = _get_best_indexes(result.start_logits, n_best_size)#按照得分选取前n_best_size个对应的候选起始位置
      end_indexes = _get_best_indexes(result.end_logits, n_best_size)#按照得分选取前n_best_size个对应的候选终止位置
      
      for start_index in start_indexes:
        for end_index in end_indexes:
          # We could hypothetically create invalid predictions, e.g., predict
          # that the start of the span is in the question. We throw out all
          # invalid predictions.
          if start_index >= len(feature.tokens):
            continue
          if end_index >= len(feature.tokens):
            continue
          if start_index not in feature.token_to_orig_map:
            continue
          if end_index not in feature.token_to_orig_map:
            continue
          if not feature.token_is_max_context.get(start_index, False):
            continue
          if end_index < start_index:
            continue
          length = end_index - start_index + 1
          if length > max_answer_length:
            continue
          prelim_predictions.append(
              _PrelimPrediction(
                  feature_index=feature_index,
                  start_index=start_index,
                  end_index=end_index,
                  start_logit=result.start_logits[start_index],
                  end_logit=result.end_logits[end_index]))

剔除完一些不合理情况后,如何选出唯一的一个组合

(1)综合起始位置和终止位置的综合得分排序

    prelim_predictions = sorted(
        prelim_predictions,
        key=lambda x: (x.start_logit + x.end_logit),
        reverse=True)

将起始位置和终止位置的得分综合起来作为该对组合的最终得分,并按照最终得分排序。

(2)根据位置获取预测的字符串

根据tokens和起始终止位置,可以得到tokens拼接出来的字符串tok_text。根据examples和起始终止位置,可以得到examples拼接出来的字符串orig_text。

通过get_final_text函数综合得到最终字符串final_text,在下面我们会介绍该函数,这里先讲逻辑。

通过n_best_size限制保存的数量,通过seen_predictions防止重复答案。

    for pred in prelim_predictions:
      if len(nbest) >= n_best_size:
        break
      feature = features[pred.feature_index]
      if pred.start_index > 0:  # this is a non-null prediction
        tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]#获取经过tokenizer处理过后的tokens的答案字符串
        orig_doc_start = feature.token_to_orig_map[pred.start_index]
        orig_doc_end = feature.token_to_orig_map[pred.end_index]
        orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] #获取原始句子example中的答案字符串
        tok_text = " ".join(tok_tokens)

        # De-tokenize WordPieces that have been split off.之前分词添加的##需要还原到原来的词
        tok_text = tok_text.replace(" ##", "")
        tok_text = tok_text.replace("##", "")

        # Clean whitespace
        tok_text = tok_text.strip()
        tok_text = " ".join(tok_text.split())
        orig_text = " ".join(orig_tokens)

        final_text = get_final_text(tok_text, orig_text, do_lower_case)
        if final_text in seen_predictions:
          continue

        seen_predictions[final_text] = True
      else:
        final_text = ""
        seen_predictions[final_text] = True

      nbest.append(
          _NbestPrediction(
              text=final_text,
              start_logit=pred.start_logit,
              end_logit=pred.end_logit))

这里我讲解下get_final_text函数, 首先他注释中的例子我看懂了,但是与代码似乎不对应(是我理解错了?)

1)找出起始和终止位置

将orig_text使用tokenizer进行处理,然后比较找出pred_text在orig_text的起始位置。

再根据pred_text的长度得到终止位置。

  tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)

  tok_text = " ".join(tokenizer.tokenize(orig_text))

  start_position = tok_text.find(pred_text)
  if start_position == -1:
    if FLAGS.verbose_logging:
      tf.logging.info(
          "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
    return orig_text
  end_position = start_position + len(pred_text) - 1

2)剔除空格的影响

orig_ns_text:剔除空格后的字符串

orig_ns_to_s_map:字典,键表示字符在orig_ns_text中的位置,值表示字符在剔除空格字符串的位置

tok_ns_text:剔除空格后的字符串

tok_ns_to_s_map:字典,键表示字符在orig_ns_text中的位置,值表示字符在剔除空格字符串的位置

  def _strip_spaces(text):
    ns_chars = []
    ns_to_s_map = collections.OrderedDict()
    for (i, c) in enumerate(text):
      if c == " ":
        continue
      ns_to_s_map[len(ns_chars)] = i
      ns_chars.append(c)
    ns_text = "".join(ns_chars)
    return (ns_text, ns_to_s_map)  
  (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
  (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

例如orig_ns_text:'DenverBroncosdefeatedtheNationalFootballConference(NFC)championCarolinaPanthers'

orig_ns_to_s_map:

OrderedDict([(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 15), (14, 16), (15, 17), (16, 18), (17, 19), (18, 20), (19, 21), (20, 22), (21, 24), (22, 25), (23, 26), (24, 28), (25, 29), (26, 30), (27, 31), (28, 32), (29, 33), (30, 34), (31, 35), (32, 37), (33, 38), (34, 39), (35, 40), (36, 41), (37, 42), (38, 43), (39, 44), (40, 46), (41, 47), (42, 48), (43, 49), (44, 50), (45, 51), (46, 52), (47, 53), (48, 54), (49, 55), (50, 57), (51, 58), (52, 59), (53, 60), (54, 61), (55, 63), (56, 64), (57, 65), (58, 66), (59, 67), (60, 68), (61, 69), (62, 70), (63, 72), (64, 73), (65, 74), (66, 75), (67, 76), (68, 77), (69, 78), (70, 79), (71, 81), (72, 82), (73, 83), (74, 84), (75, 85), (76, 86), (77, 87), (78, 88)])

tok_ns_text:'denverbroncosdefeatedthenationalfootballconference(nfc)championcarolinapanthers'

tok_ns_to_s_map:

OrderedDict([(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 15), (14, 16), (15, 17), (16, 18), (17, 19), (18, 20), (19, 21), (20, 22), (21, 24), (22, 25), (23, 26), (24, 28), (25, 29), (26, 30), (27, 31), (28, 32), (29, 33), (30, 34), (31, 35), (32, 37), (33, 38), (34, 39), (35, 40), (36, 41), (37, 42), (38, 43), (39, 44), (40, 46), (41, 47), (42, 48), (43, 49), (44, 50), (45, 51), (46, 52), (47, 53), (48, 54), (49, 55), (50, 57), (51, 59), (52, 60), (53, 61), (54, 63), (55, 65), (56, 66), (57, 67), (58, 68), (59, 69), (60, 70), (61, 71), (62, 72), (63, 74), (64, 75), (65, 76), (66, 77), (67, 78), (68, 79), (69, 80), (70, 81), (71, 83), (72, 84), (73, 85), (74, 86), (75, 87), (76, 88), (77, 89), (78, 90)])

3)从tok的起始终止位置转成orig中的起始终止位置

tok_text的起始位置找到tok_ns_text的位置,然后根据tok_ns_text位置和orig_ns_text位置一一对应的规则,得出了起始位置在orig_ns_text中的位置,再根据orig_ns_to_s_map的映射规则得到起始位置在orig_text的位置。

  # We then project the characters in `pred_text` back to `orig_text` using
  # the character-to-character alignment.
  tok_s_to_ns_map = {}
  for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
    tok_s_to_ns_map[tok_index] = i

  orig_start_position = None
  if start_position in tok_s_to_ns_map:
    ns_start_position = tok_s_to_ns_map[start_position]
    if ns_start_position in orig_ns_to_s_map:
      orig_start_position = orig_ns_to_s_map[ns_start_position]

  if orig_start_position is None:
    if FLAGS.verbose_logging:
      tf.logging.info("Couldn't map start position")
    return orig_text

  orig_end_position = None
  if end_position in tok_s_to_ns_map:
    ns_end_position = tok_s_to_ns_map[end_position]
    if ns_end_position in orig_ns_to_s_map:
      orig_end_position = orig_ns_to_s_map[ns_end_position]

  if orig_end_position is None:
    if FLAGS.verbose_logging:
      tf.logging.info("Couldn't map end position")
    return orig_text

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值