2021SC@SDUSC
根据“基于常识知识的推理问题”的第一篇技术报告(暨综述)可知,DrFact模型的第一步需要初始化事实集合——即对q编码,通过最大内积搜索检索和q相关的事实,从这些事实中,选择包含q中的概念的事实作为初始的。由此,我展开这次的源代码分析工作。
在这次的源代码分析之前,首先先要对初始化步骤中的一些概念进行一些了解,因此我会在这次源代码分析报告中先对初始化步骤中的概念进行介绍。q和事实矩阵D自然不必多言,在初始化步骤中,最为重要的算法概念即是最大内积搜索,因此在此着重介绍一下最大内积搜索的概念。在介绍完最大内积搜索的概念之后,我再开始这次的源代码分析。
一、最大内积搜索
最大内积搜索(MIPS, Maximum Inner Product Search)是机器学习中十分常用的一个方法,其思想内核非常简单。假设你有一堆d维向量,组成集合X,现在输入了一个同样维度的查询向量q(query),请从X中找出一个p,使得p和q的点积在集合X是最大的。用公式来描述的话即是
MIPS会不由得让人想到最近邻算法(NN)。如果把上面的定义改成找一个p使得p和q的距离最小,MIPS就转变为了NN。用公式来描述(使用欧氏距离进行计算)的话即是
这里顺便提一提,假设X中的向量模长都一样,那两个问题其实是等价的。然而在很多实际场景中,例如推荐系统里的各种Embedding,以及在我目前正在学习的OpenCSR项目中,使用BERT编码后得到的句向量,都不能满足这个约束。不过,通过引入NN,我们也得以更好地理解MIPS的思想。虽然我们可以用BERT这样的大型模型获得很好的准确度,但如果用BERT直接对语料库中的所有问题进行计算,将耗费大量的时间。所以可以先用关键词检索或者向量检索从语料库里召回一些候选语料后再做高精度匹配。这也是DrFact模型第一步初始化的核心含义。
二、 源代码分析
本次分析的源文件是add_init_facts.py文件,其作用即是初始化事实集合。
2.1 调用模块
首先,add_init_facts.py文件调用了以下这些python库,均在上次源代码分析中有所体现,故不再介绍。
import json
from absl import app
from absl import flags
from absl import logging
from tqdm import tqdm
2.2 flags参数
然后是通过flags定义的全局变量——6个字符串变量以及一个整型变量。
- linked_qas_file的初始值为None,表示指向数据集文件的路径。
- drfact_format_gkb_file的初始值为None,表示指向gkb语料库的路径。
- ret_result_file的初始值为None,表示指向数据集文件的路径。
- sup_facts_file的初始值为None,表示指向数据集文件的路径。
- output_file的初始值为None,表示指向数据集文件的路径。
- split的初始值为字符串"train",表示指向数据集文件的路径。
- max_num_fact的初始值为1000,表示指向数据集文件的路径。
不过在我看来,这里的注释是存在一定问题的,比如最后那个整型参数,显然不可能是一个路径,这个问题有待日后讨论。
FLAGS = flags.FLAGS
flags.DEFINE_string("linked_qas_file", None, "Path to dataset file.")
flags.DEFINE_string("drfact_format_gkb_file", None, "Path to gkb corpus.")
flags.DEFINE_string("ret_result_file", None, "Path to dataset file.")
flags.DEFINE_string("sup_facts_file", None, "Path to dataset file.")
flags.DEFINE_string("output_file", None, "Path to dataset file.")
flags.DEFINE_string("split", "train", "Path to dataset file.")
flags.DEFINE_integer("max_num_facts", 1000, "Path to dataset file.")
2.3 主函数main
有了以上这些python库以及flags定义的变量介绍,接下来开始着重解析add_init_facts.py文件中最重要,也是唯一的一个函数方法main(_)。
主函数main中首先打开了drfact_format_gkb_file指向的文件,将其中按行拆分存放后,对该列表进行遍历。每个循环中对当前行的的内容使用json.loads()方法进行提取并存放于对象instance中,然后使用当前行事实的原有id作为字典的键,当前行索引和当前事实行对象instance作为对应的值,分别构建两个字典,其中gkb_id_to_id字典表示instance原有id和当前索引的映射,而facts_dict则是instance原有id与其内容的映射表。
def main(_):
"""Main funciton."""
logging.set_verbosity(logging.INFO)
with open(FLAGS.drfact_format_gkb_file) as f:
logging.info("Reading %s..."%f.name)
gkb_id_to_id = {}
facts_dict = {}
cur_fact_ind = 0
for line in f.read().splitlines():
instance = json.loads(line)
gkb_id_to_id[instance["id"]] = cur_fact_ind
facts_dict[instance["id"]] = instance
cur_fact_ind += 1
然后再打开ret_result_file指向的文件,将其中内容按行抽取后使用json.loads()函数进行提取,遍历存放于ret_data列表中。
with open(FLAGS.ret_result_file) as f:
logging.info("Reading %s..."%f.name)
ret_data = [json.loads(line) for line in f.read().split("\n") if line]
再打开linked_qas_file指向的文件,采用与ret_result_file相同的方式处理,结果存放于data列表中。注意,这里需要对data列表和ret_data列表的长度是否相等进行检测,如果不等则抛出异常。
with open(FLAGS.linked_qas_file) as f:
logging.info("Reading QAS(-formatted) data...%s"%f.name)
jsonlines = f.read().splitlines()
data = [json.loads(jsonline) for jsonline in jsonlines]
assert len(ret_data) == len(data)
接下来对sup_facts_file指向的文件进行操作。如果该路径存在,则将改文件打开,延续上述方法对内容按行拆分处理并存放于列表sup_facts_data中。最后,检测data列表和sup_facts_data长度是否相等,不等则抛出异常。
sup_facts_data = []
if FLAGS.sup_facts_file:
with open(FLAGS.sup_facts_file) as f:
jsonlines = f.read().splitlines()
sup_facts_data = [json.loads(jsonline) for jsonline in jsonlines]
assert len(data) == len(sup_facts_data)
接下来对data进行遍历,根据其索引获取对应于ret_data列表中对应的字典内容,再使用"results"和"all_ret_facts"键找出该字典中的字典中的all_ret_facts对应值,存放于all_ret_facts变量中。对于每轮遍历获得的样本字典ins,为其添加一个值为空列表的新键值对init_facts和一个值为二空元列表的新键值对sup_facts。
将样本字典ins中的all_answer_concepts对应值列表中的kb_id封装为集合answer_concepts,这样可以保证其中无重复元素。同理将样本字典ins中的entities对应值列表中的kb_id封装为集合形式,并与COMMON_CONCEPT字典作差,得到差集合question_concepts。
在此轮循环中,再对all_ret_facts进行遍历,根据当前相关事实的事实id将其从事实集合中找出存放于fact变量,再将该变量字典中的mentions值中所有kb_id值取出封装成集合fact_concepts。此时,如果fact_concepts和question_concepts内容存在交集,则为当前样本ins字典的init_facts值设为fid映射的当前id和s的元组。再将包含的答案设为fact_concepts和answer_concepts的交集。循环此步,直到init_facts值长度大于上限max_num_facts时停止,此时init_facts装有大量备选回答。contain_answer变为True,则本轮中num_cover+1。
然后,为ins构建值为conpcet_set长度的num_mentioned_concepts键值对,再将init_facts中内容抽取成为集合init_fact_set。接下来如果参数split为"train",则对尚为supfacts二元列表中内容进行遍历。若内容长度为1则将该二元列表中索引为0,1和0,0的内容抽取并填回当前内容列表中,若长度为2则将两个上述内容抽取并填回。若这些fid不在其init_fact_set中,则对init_fact_set进行同步补充。如果参数split不为"train"且init_facts中备选回答不足,则将事实字典facts_dict中的提到的概念补充进入init_facts中直到其充满为止。
完成以上所有工作后,算是完成对data中一个样本ins的一次处理,将该样本ins填入new_data列表中。由此,循环遍历ins直到data遍历完成,此时new_data列表中包含所有ins处理后的内容,将其写入输出文件中,路径为参数output_fle。
new_data = []
num_covered = 0
for ind, ins in tqdm(enumerate(data), desc=FLAGS.linked_qas_file, total=len(data)):
all_ret_facts = ret_data[ind]["results"]["all_ret_facts"]
ins["init_facts"] = []
ins["sup_facts"] = [[], []]
# question_concepts = set([c["kb_id"] for c in ins["entities"]])
answer_concepts = set([c["kb_id"] for c in ins["all_answer_concepts"]])
question_concepts = set([c["kb_id"] for c in ins["entities"]]) - set(COMMON_CONCEPTS)
is_covered = False
concept_set = set()
for fid, s in all_ret_facts:
fact = facts_dict[fid]
fact_concepts = set([m["kb_id"] for m in fact["mentions"]])
# TODO: this is equvilent to dense_first and then sparse
# if FLAGS.split == "train":
contain_answer = False
if fact_concepts & question_concepts:
# keep only question-mentioned facts as the first hop
ins["init_facts"].append((gkb_id_to_id[fid], s))
contain_answer = fact_concepts & answer_concepts
# elif fact_concepts & answer_concepts:
# # not mention question concept but mention answer
# # answer only concepts
# ins["sup_facts"][0].append((gkb_id_to_id[fid], s))
# ins["sup_facts"][1].append((gkb_id_to_id[fid], s))
# else:
# ins["init_facts"].append((gkb_id_to_id[fid], s))
# continue
# else:
# ins["init_facts"].append((gkb_id_to_id[fid], s))
# # if len(ins["init_facts"]) < FLAGS.max_num_facts or contain_answer: # Cause some problems
# else:
# if len(ins["init_facts"]) < FLAGS.max_num_facts:
# ins["init_facts"].append((gkb_id_to_id[fid], s))
if len(ins["init_facts"]) >= FLAGS.max_num_facts:
break
if contain_answer:
is_covered = True
concept_set.update(fact_concepts)
if is_covered:
num_covered += 1
ins["num_mentioned_concepts"] = len(concept_set)
init_fact_set = set([fid for fid, _ in ins["init_facts"]])
if FLAGS.split == "train" and sup_facts_data:
for sup_item in sup_facts_data[ind]["sup_facts"]:
item_supfacts = sup_item[0] # a list of facts
if len(item_supfacts) == 1: # One-hop quesiton
fid = item_supfacts[0][0]
score = item_supfacts[0][1]
ins["sup_facts"][0].append((fid, score))
# TODO: put the same fact to the second hop slot
ins["sup_facts"][1].append((fid, score))
if fid not in init_fact_set:
ins["init_facts"].append((fid, score))
init_fact_set.add(fid)
elif len(item_supfacts) == 2: # Two-hop quesiton
fid_1 = item_supfacts[0][0]
score_1 = item_supfacts[0][1]
fid_2 = item_supfacts[1][0]
score_2 = item_supfacts[1][1]
ins["sup_facts"][0].append((fid_1, score_1))
ins["sup_facts"][1].append((fid_2, score_2))
if fid_1 not in init_fact_set:
# Only put the the first hop as the initial facts for training time.
ins["init_facts"].append((fid_1, score_1))
# Make them unique
ins["sup_facts"][0] = list(set(ins["sup_facts"][0]))
ins["sup_facts"][1] = list(set(ins["sup_facts"][1]))
elif FLAGS.split != "train" and len(ins["init_facts"]) < FLAGS.max_num_facts:
for fid, s in all_ret_facts:
fact = facts_dict[fid]
fact_concepts = set([m["kb_id"] for m in fact["mentions"]])
ins["init_facts"].append((gkb_id_to_id[fid], s))
if len(ins["init_facts"]) >= FLAGS.max_num_facts:
break
new_data.append(ins)
with open(FLAGS.output_file, "w") as f:
logging.info("num_covered: %d", num_covered)
logging.info("len(new_data): %d", len(new_data))
logging.info("Coverage:%.2f", num_covered/len(new_data))
logging.info("Writing to %s", f.name)
f.write("\n".join([json.dumps(i) for i in new_data])+"\n")
logging.info("Done.")
综上,便是这次源代码分析的全部内容。