“基于常识知识的推理问题”源代码分析-受影响的源代码

2021SC@SDUSC

在前几次的源代码分析报告中,我已经对于DrFact的算法的前两步进行了分析。按照顺序,我今天应该首先分析算法的第三步在源代码中是如何实现的。不过,基于前两周中我对于DrFact模型是如何参考了DrKit模型的阐述,我想要趁热打铁,在今天着重分析一下剩余的源代码中受到了这些参考影响的源代码,也算是对模型中不属于算法主体部分的内容进行一下系统化的阐述。

一、input_fns.py模块

input_fns.py是DrFact模型中,一个极为重要的源代码模块。在这个源代码模块中存放有多个类,其名分别为Example,InputFeatures,FeatureWriter和OpenCSRDataset。这些类的主要用途是用于将不同的数据集处理为一种公共格式。在这个py源文件中,引入了如下的python模块,其中就有来自于DrKit模块的参考input_fns。

import collections
import json
import random

from bert import tokenization
from language.labs.drkit import input_fns as input_utils
import tensorflow.compat.v1 as tf
from tqdm import tqdm

from tensorflow.contrib import data as contrib_data

接下来,我将对这个源文件中定义到的几个类进行分析。 

1.1 Example类

首先,我将介绍input_fns.py源文件中的Example类,这个类十分重要,是后续多个类的依赖。它的构造函数如下所示,可以看到,该类共有9个成员变量,且这9个成员变量均通过构造函数的形参进行初始化。不过在此之中,问题实体answer_entity,概念选择项choice2concepts,正确的选择correct_choice,排除掉的回答集合exclude_set,初始事实集合init_facts和支持事实集合sup_facts的缺省值均为None。

class Example(object):
    """A single training/test example for QA."""

    def __init__(
            self,
            qas_id,
            question_text,
            subject_entity,  # The concepts mentioned by the question.
            answer_entity=None,  # The concept(s) in the correct choice
            choice2concepts=None,  # for evaluation
            correct_choice=None,  # for evaluation
            exclude_set=None,  # concepts in the question and wrong choices
            init_facts=None,  # pre-computed init facts list. [(fid, score), ...]
            sup_facts=None
    ):
        self.qas_id = qas_id
        self.question_text = question_text
        self.subject_entity = subject_entity
        self.answer_entity = answer_entity
        self.choice2concepts = choice2concepts
        self.correct_choice = correct_choice
        self.exclude_set = exclude_set
        self.init_facts = init_facts
        self.sup_facts = sup_facts

此外,Example类还有两个内置函数,分别是__str__()和__repr__()。意义与一般类相同,在此不多赘述。

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        s = ""
        s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
        s += ", question_text: %s" % (
            tokenization.printable_text(self.question_text))
        return s

可见,Example类的作用是存放一个问答中单个的训练/测试样本。

1.2 InputFeature类

第二个要介绍的类是InputFeature类,根据其注释可知,这个类的作用是存放一个数据集中的单个特征集合。这个类的定义也很简单,只有一个构造函数,其具体代码如下。可以看出,这个类共有10个成员变量,且均根据构造形参初始化,而其中的与Example中缺省None的几个变量也均缺省值为None。

class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self,
                 qas_id,
                 qry_tokens,
                 qry_input_ids,
                 qry_input_mask,
                 qry_entity_id,
                 answer_entity=None,
                 exclude_set=None,
                 init_facts=None,
                 sup_fact_1hop=None,
                 sup_fact_2hop=None, ):
        self.qas_id = qas_id
        self.qry_tokens = qry_tokens
        self.qry_input_ids = qry_input_ids
        self.qry_input_mask = qry_input_mask
        self.qry_entity_id = qry_entity_id
        self.answer_entity = answer_entity
        self.exclude_set = exclude_set
        self.init_facts = init_facts
        self.sup_fact_1hop = sup_fact_1hop
        self.sup_fact_2hop = sup_fact_2hop

3. FeatureWriter类

第三个要介绍的类是FeatureWriter类,这个类会稍微复杂一些,它的构造函数的具体代码如下。可以看出这个类共有5个成员变量,其中num_features初始化为0,私有变量_writer则使用了tensorflow.python_io.TFRecordWriter()函数,根据形参中的filename构造。而其余的filename,is_training和has_bridge则均由形参获得。

class FeatureWriter(object):
    """Writes InputFeature to TF example file."""

    def __init__(self, filename, is_training, has_bridge):
        self.filename = filename
        self.is_training = is_training
        self.has_bridge = has_bridge
        self.num_features = 0
        self._writer = tf.python_io.TFRecordWriter(filename)

除了构造函数外,该类中还有两个函数,分别是process_feature()函数和close()函数。

process_feature()函数的作用是将独立特征写到TFRecordWriter中,作为一个训练样例。具体实现方式如代码所示,即是通过三个内置函数,将特征feature中的qas_id,qry_input_ids,qry_input_mask,qry_entity_id等成员值经过处理后,存到新的有序字典features的qas_ids等值中。其中有如answer_entities和exclude_set这样的成员值则需要在is_training为真时才处理。

至于close()函数则很简单,即是将私有成员对象_writer关闭。

    def process_feature(self, feature):
        """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
        # The feature object is actually of Example class.
        self.num_features += 1

        def create_int_feature(values):
            feature = tf.train.Feature(
                int64_list=tf.train.Int64List(value=list(values)))
            return feature

        def create_float_feature(values):
            feature = tf.train.Feature(
                float_list=tf.train.FloatList(value=list(values)))
            return feature

        def create_bytes_feature(value):
            return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

        features = collections.OrderedDict()
        features["qas_ids"] = create_bytes_feature(feature.qas_id)
        features["qry_input_ids"] = create_int_feature(feature.qry_input_ids)
        features["qry_input_mask"] = create_int_feature(feature.qry_input_mask)
        features["qry_entity_id"] = create_int_feature(feature.qry_entity_id)
        # Prepare Init Facts as features
        init_fact_ids = [x[0] for x in feature.init_facts]
        init_fact_scores = [x[1] for x in feature.init_facts]

        features["init_fact_ids"] = create_int_feature(init_fact_ids)
        features["init_fact_scores"] = create_float_feature(init_fact_scores)

        if self.is_training:
            features["answer_entities"] = create_int_feature(feature.answer_entity)
            features["exclude_set"] = create_int_feature(feature.exclude_set)
            # TODO: add a hp (10) to limit the num of sup facts
            max_sup_fact_num = None  # or None
            sup_fact_1hop_ids = list(set([x[0] for x in feature.sup_fact_1hop]))[:max_sup_fact_num]
            sup_fact_2hop_ids = list(set([x[0] for x in feature.sup_fact_2hop]))[:max_sup_fact_num]
            features["sup_fact_1hop_ids"] = create_int_feature(sup_fact_1hop_ids)
            features["sup_fact_2hop_ids"] = create_int_feature(sup_fact_2hop_ids)

        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        self._writer.write(tf_example.SerializeToString())

    def close(self):
        self._writer.close()

4.OpenCSRDataset类

最后要介绍的类是OpenCSRDataset类。通过注释可以看出,这个类的用途是存放OpenCSR数据集。

这个类的构造函数代码如下所示,可以看出,这个类中共有6个成员变量,其中gt_file,max_qry_length和is_training是根据形参传递进行初始化的,而examples则是通过下面将会描述的read_examples()函数,根据entity2id读取in_file指向的JSON文件得到的样例结果,num_examples则代表了examples的长度。

至于剩下最后一个成员变量,则通过下述方法进行定义。首先判断是否在训练中(is_training),是则对样例随机化。然后将形参中的tfrecord_filename指向的文件通过打开之前描述过的writer的方式将其写入TFRecord文件中,最后构建好names_to_features的字典后,使用之前描述过的input_fn_builder()函数,构造出该类的最后一个成员变量input_fn。

class OpenCSRDataset(object):
    """Reads the open commonsense reasoning dataset and converts to TFRecords."""

    def __init__(self, in_file, tokenizer, subject_mention_probability,
                 max_qry_length, is_training, entity2id, tfrecord_filename):
        """Initialize dataset."""
        del subject_mention_probability

        self.gt_file = in_file
        self.max_qry_length = max_qry_length
        self.is_training = is_training

        # Read examples from JSON file.
        self.examples = self.read_examples(in_file, entity2id)
        self.num_examples = len(self.examples)

        if is_training:
            # Pre-shuffle the input to avoid having to make a very large shuffle
            # buffer in in the `input_fn`.
            rng = random.Random(12345)
            rng.shuffle(self.examples)

        # Write to TFRecords file.
        writer = FeatureWriter(
            filename=tfrecord_filename,
            is_training=self.is_training,
            has_bridge=False)
        convert_examples_to_features(
            examples=self.examples,
            tokenizer=tokenizer,
            max_query_length=self.max_qry_length,
            entity2id=entity2id,
            output_fn=writer.process_feature)
        writer.close()

        # Create input_fn.
        names_to_features = {
            "qas_ids": tf.FixedLenFeature([], tf.string),
            "qry_input_ids": tf.FixedLenFeature([self.max_qry_length], tf.int64),
            "qry_input_mask": tf.FixedLenFeature([self.max_qry_length], tf.int64),
            "qry_entity_id": tf.VarLenFeature(tf.int64),
            "init_fact_ids": tf.VarLenFeature(tf.int64),
            "init_fact_scores": tf.VarLenFeature(tf.float32),
        }
        if is_training:
            names_to_features["answer_entities"] = tf.VarLenFeature(tf.int64)
            names_to_features["exclude_set"] = tf.VarLenFeature(tf.int64)
            names_to_features["sup_fact_1hop_ids"] = tf.VarLenFeature(tf.int64)
            names_to_features["sup_fact_2hop_ids"] = tf.VarLenFeature(tf.int64)

        self.input_fn = input_fn_builder(
            input_file=tfrecord_filename,
            is_training=self.is_training,
            drop_remainder=False,
            names_to_features=names_to_features)

除了构造方法之外,这个类还有一个read_examples()的函数,其具体函数定义的代码如下所示,主要功能在于读取json文件中的内容并将其存放于一个样例列表中。函数定义其中的代码逻辑与最开始分析的一篇源代码分析中的思路比较相似,在这里就不再详细展开了。

    def read_examples(self, queries_file, entity2id):
        """Read a json file into a list of Example."""
        self.max_qry_answers = 0
        num_qrys_without_answer, num_qrys_without_all_answers = 0, 0
        num_qrys_without_entity, num_qrys_without_all_entities = 0, 0
        tf.logging.info("Reading examples from %s", queries_file)
        with tf.gfile.Open(queries_file, "r") as reader:
            examples = []
            one_hop_num = 0
            for line in tqdm(reader, desc="Reading from %s" % reader.name):
                item = json.loads(line.strip())

                qas_id = item["_id"]
                question_text = item["question"]

                question_entities = []
                for entity in item["entities"]:
                    if entity["kb_id"].lower() in entity2id:
                        question_entities.append(entity["kb_id"].lower())
                if not question_entities:
                    num_qrys_without_entity += 1
                    if self.is_training:
                        continue
                if len(question_entities) != len(item["entities"]):
                    num_qrys_without_all_entities += 1

                    # make up the format
                answer_concepts = list(set([c["kb_id"] for c in item["all_answer_concepts"]]))  # TODO: decomp?
                choice2concepts = {}
                choice2concepts[item["answer"]] = answer_concepts
                # choice2concepts = item["choice2concepts"]
                answer_txt = item["answer"]
                assert answer_txt in choice2concepts
                answer_entities = []

                if self.is_training:
                    # Training time, we use all concepts in the correct choice.
                    answer_concepts = list(
                        set([c["kb_id"] for c in item["all_answer_concepts_decomp"]]))  # TODO: decomp?
                    choice2concepts[item["answer"]] = answer_concepts
                    for answer_concept in answer_concepts:
                        if answer_concept in entity2id:
                            # TODO: add an arg for decide if only use the longest concept.
                            answer_entities.append(entity2id[answer_concept])
                else:
                    # Test time, we use unique concepts in the correct choice.
                    for answer_concept in choice2concepts[answer_txt]:
                        if answer_concept in entity2id:
                            # TODO: add an arg for decide if only use the longest concept.
                            answer_entities.append(entity2id[answer_concept])

                if len(answer_entities) > self.max_qry_answers:
                    self.max_qry_answers = len(answer_entities)
                    tf.logging.warn("%s has %d linked entities", qas_id,
                                    len(question_entities))

                if not answer_entities:
                    num_qrys_without_answer += 1
                    if self.is_training:
                        continue
                if len(answer_entities) < len(item["answer_concepts"]):
                    num_qrys_without_all_answers += 1

                # Define the exclude_entities as the question entities,
                # and the concepts mentioned by wrong choices.
                exclude_entities = question_entities[:]
                # for choice, concepts in choice2concepts.items():
                #   if choice == answer_txt:
                #     continue
                #   for non_answer_concept in concepts:
                #     if non_answer_concept in entity2id:
                #       exclude_entities.append(non_answer_concept.lower())
                init_facts = item["init_facts"]
                sup_facts = item["sup_facts"]

                if sup_facts[0] == sup_facts[1]:
                    one_hop_num += 1
                example = Example(
                    qas_id=qas_id,
                    question_text=question_text,
                    subject_entity=question_entities,
                    answer_entity=answer_entities,
                    correct_choice=answer_txt,
                    choice2concepts=choice2concepts,
                    exclude_set=exclude_entities,
                    init_facts=init_facts,
                    sup_facts=sup_facts
                )
                examples.append(example)

        tf.logging.info("Number of valid questions = %d", len(examples))
        tf.logging.info("Number of one-hop questions = %d", one_hop_num)
        tf.logging.info("Ratio of one-hop questions = %.2f", one_hop_num / len(examples))
        tf.logging.info("Questions without any answer = %d",
                        num_qrys_without_answer)
        tf.logging.info("Questions without all answers = %d",
                        num_qrys_without_all_answers)
        tf.logging.info("Questions without any entity = %d",
                        num_qrys_without_entity)
        tf.logging.info("Questions without all entities = %d",
                        num_qrys_without_all_entities)
        tf.logging.info("Maximum answers per question = %d", self.max_qry_answers)

        return examples

二、model_fns.py模块

model_fns.py模块同样也是一个借鉴了不少来自于DrKit模型内容的模块,不过其中的内容实在太多,因此在这里仅仅介绍其中最为突出的一个函数——multi_hop_fact()函数,其具体函数定义如下所示。

从注释中我们可以知道,这个函数的作用是输入的事实数据经过多轮的增殖传播之后,得到最终的输出事实集合。

在这个函数中,首先将概念集合中的词汇全部通过BERT编码转化为词向量形式。

接下来使用MIPS(最大内积搜索),构建出一个特征数据库。

最后,得到一个没有评分的问题实体。

通过这个步骤我们可以看出这应该是一个模拟Fact-Follow迭代的过程。

def multi_hop_fact(qry_input_ids,
                   qry_input_mask,
                   qry_entity_ids,
                   init_fact_ids,
                   init_fact_scores,
                   entity_ids,
                   entity_mask,
                   ent2fact_ind,
                   ent2fact_val,
                   fact2ent_ind,
                   fact2ent_val,
                   fact2fact_ind,
                   fact2fact_val,
                   is_training,
                   use_one_hot_embeddings,
                   bert_config,
                   qa_config,
                   fact_mips_config,
                   num_hops,
                   exclude_set=None,
                   is_printing=True,
                   sup_fact_index_list=None,
                   ):
  """Multi-hops of propagation from input to output facts.

  Args:
    qry_input_ids:
    qry_input_mask:
    qry_entity_ids:
    entity_ids: (entity_word_ids) [num_entities, max_entity_len] Tensor holding
      word ids of each entity.
    entity_mask: (entity_word_masks) [num_entities, max_entity_len] Tensor with
      masks into word ids above.
    ent2fact_ind:
    ent2fact_val:
    fact2ent_ind:
    fact2ent_val:
    fact2fact_ind:
    fact2fact_val:
    is_training:
    use_one_hot_embeddings:
    bert_config:
    qa_config:
    fact_mips_config:
    num_hops:
    exclude_set:
    is_printing:
    sup_fact_index_list:

  Returns:
    layer_entities:
    layer_facts:
    layer_dense:
    layer_sp:
    batch_entities_nosc:
    qry_seq_emb:
  """
  del exclude_set  # Not used for now.

  # for embedding of concepts
  with tf.variable_scope("qry/bow"):
    # Note: trainable word weights over the BERT vocab for encoding concepts
    word_weights = tf.get_variable(
        "word_weights", [bert_config.vocab_size, 1],
        dtype=tf.float32,
        initializer=tf.ones_initializer())  # Inited as all ones.
  

  # MIPS search for facts.  Build fact feature Database
  with tf.device("/cpu:0"):
    tf_fact_db, fact_mips_search_fn = search_utils.create_mips_searcher(
        fact_mips_config.ckpt_var_name,
        # [fact_mips_config.num_facts, fact_mips_config.emb_size],
        fact_mips_config.ckpt_path,
        fact_mips_config.num_neighbors,
        local_var_name="scam_init_barrier_fact")

  qry_seq_emb, word_emb_table, qry_hidden_size = model_utils.shared_qry_encoder_v2(
      qry_input_ids, qry_input_mask, is_training, use_one_hot_embeddings,
      bert_config, qa_config)

  batch_size = tf.shape(qry_input_ids)[0]
  # Get question entities w/o scores.
  batch_qry_entities = tf.SparseTensor(
      indices=tf.concat([
          qry_entity_ids.indices[:, 0:1],
          tf.cast(tf.expand_dims(qry_entity_ids.values, 1), tf.int64)
      ],
                        axis=1),
      values=tf.ones_like(qry_entity_ids.values, dtype=tf.float32),
      dense_shape=[batch_size, qa_config.num_entities])

综上,这就是本次代码分析的全部内容,着重描述了input_fns和model_fns这两个借鉴了DrKit模型的模块的全部/部分工作原理。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值