Scalable Multi-Hop Relational Reasoning for Knowledge-Aware Question Answering
This paper injects external knowledge into the model for reasoning and achieves SOTA results on the CommonsenseQA dataset.
The external knowledge source is ConceptNet.
The code is organized into five steps:
- download the relevant datasets
- preprocess the datasets
- hyperparameter search (optional)
- training
- evaluation
This post covers step 2: dataset preprocessing.
The previous post walked through the first half of the CSQA training-data pipeline; this one continues with it:
extracting the concepts mentioned in the training data, based on the ConceptNet concept vocabulary.
import json
from multiprocessing import Pool

import spacy
from spacy.matcher import Matcher
from tqdm import tqdm

# module-level state shared by the worker processes (initialized lazily in
# ground_qa_pair); nltk_stopwords and blacklist are defined elsewhere in the file
nlp = None
matcher = None
CPNET_VOCAB = None


def ground(statement_path, cpnet_vocab_path, pattern_path, output_path, num_processes=1, debug=False):
    # global PATTERN_PATH, CPNET_VOCAB
    # if PATTERN_PATH is None:
    #     PATTERN_PATH = pattern_path
    #     CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path)

    # load the CSQA statement data
    sents = []
    answers = []
    with open(statement_path, 'r', encoding='utf-8') as fin:
        lines = [line for line in fin]

    if debug:
        lines = lines[192:195]
        print(len(lines))
    for line in lines:
        if line == "":  # note: lines keep their trailing '\n', so this check rarely fires
            continue
        j = json.loads(line)
        for statement in j["statements"]:  # statements in textual-entailment form
            sents.append(statement["statement"])
        for answer in j["question"]["choices"]:
            ans = answer['text']
            # ans = " ".join(answer['text'].split("_"))
            try:
                assert all([i != "_" for i in ans])  # answer text should contain no underscores
            except Exception:
                print(ans)
            answers.append(ans)

    # match the statements and answers against ConceptNet
    res = match_mentioned_concepts(sents, answers, num_processes, pattern_path, cpnet_vocab_path)
    res = prune(res, cpnet_vocab_path)  # pruning

    # check_path(output_path)
    with open(output_path, 'w', encoding='utf-8') as fout:
        for dic in res:
            fout.write(json.dumps(dic) + '\n')

    print(f'grounded concepts saved to {output_path}')
    print()
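For reference, the field accesses above (j["statements"], j["question"]["choices"]) imply that each line of the statement file is a JSON record of roughly the following shape. The values below are invented for illustration, not taken from the dataset:

# Illustrative shape of one line of the statement file, inferred from the
# field accesses in ground(); the values are made up for illustration.
example_line = {
    "statements": [
        # one entailment-style statement per answer choice
        {"statement": "Revolving doors serve as a security measure at a bank."},
        {"statement": "Revolving doors serve as a security measure at a library."},
    ],
    "question": {
        "choices": [
            {"text": "bank"},
            {"text": "library"},
        ]
    },
}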
match_mentioned_concepts extracts the concepts mentioned in the questions and answers, parallelized over a process pool.
On how imap is used here, see the separate post analyzing the imap method.
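As a quick refresher, here is a minimal, self-contained sketch of Pool.imap combined with tqdm; the square worker is just a stand-in for ground_qa_pair:

from multiprocessing import Pool
from tqdm import tqdm

def square(x):
    # stand-in worker; in the real code this is ground_qa_pair
    return x * x

if __name__ == '__main__':
    with Pool(4) as p:
        # imap yields results lazily and in input order, so tqdm can
        # display progress as each item completes
        results = list(tqdm(p.imap(square, range(100)), total=100))
    print(results[:5])  # [0, 1, 4, 9, 16]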
def match_mentioned_concepts(sents, answers, num_processes, pattern_path, cpnet_vocab_path):
    input_args = []
    for s, a in zip(sents, answers):
        input_args.append((s, a, pattern_path, cpnet_vocab_path))
    with Pool(num_processes) as p:
        # ground each (statement, answer) pair against the concept vocabulary in parallel
        res = list(tqdm(p.imap(ground_qa_pair, input_args), total=len(sents)))
    return res
The core extraction function:
def ground_qa_pair(qa_pair):
    s, a, pattern_path, cpnet_vocab_path = qa_pair
    global nlp, matcher, CPNET_VOCAB
    if nlp is None or matcher is None or CPNET_VOCAB is None:
        nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat'])  # load the English model
        nlp.add_pipe(nlp.create_pipe('sentencizer'))  # add a sentencizer pipeline component (spaCy v2 API)
        matcher = load_matcher(nlp, pattern_path)  # build the rule-based concept matcher
        CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path)

    all_concepts = ground_mentioned_concepts(nlp, matcher, s, a)  # concepts mentioned in the full statement
    answer_concepts = ground_mentioned_concepts(nlp, matcher, a)  # concepts mentioned in the answer
    question_concepts = all_concepts - answer_concepts  # question concepts = statement concepts minus answer concepts
    if len(question_concepts) == 0:
        question_concepts = hard_ground(nlp, s, CPNET_VOCAB)  # not very possible
    if len(answer_concepts) == 0:
        answer_concepts = hard_ground(nlp, a, CPNET_VOCAB)  # some case

    # question_concepts = question_concepts - answer_concepts
    question_concepts = sorted(list(question_concepts))
    answer_concepts = sorted(list(answer_concepts))
    return {"sent": s, "ans": a, "qc": question_concepts, "ac": answer_concepts}
Extracting the concepts. For the usage of Matcher, see spaCy's Matcher documentation.
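Since ground_mentioned_concepts below recovers concept names via nlp.vocab.strings[match_id], here is a minimal sketch of that mechanism, assuming the spaCy v2 API this code targets (hence the v2-style matcher.add(key, None, pattern) signature, also used for ans_matcher below):

import spacy
from spacy.matcher import Matcher

nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat'])
matcher = Matcher(nlp.vocab)
# one rule per concept; the rule key is the '_'-joined concept name
matcher.add('sports_car', None, [{'LOWER': 'sports'}, {'LOWER': 'car'}])

doc = nlp('he saved up to buy a sports car')
for match_id, start, end in matcher(doc):
    # the vocab's string store maps the hashed match id back to the rule key
    print(doc[start:end].text, '->', nlp.vocab.strings[match_id])
# prints: sports car -> sports_car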
def ground_mentioned_concepts(nlp, matcher, s, ans=None):  # s: statement text, ans: candidate answer
    s = s.lower()
    doc = nlp(s)
    matches = matcher(doc)

    mentioned_concepts = set()
    span_to_concepts = {}

    if ans is not None:
        ans_matcher = Matcher(nlp.vocab)
        ans_words = nlp(ans)
        # print(ans_words)
        # tokenize the answer, then register it as a pattern in the matcher
        ans_matcher.add(ans, None, [{'TEXT': token.text.lower()} for token in ans_words])
        ans_match = ans_matcher(doc)  # find answer mentions inside the statement
        ans_mentions = set()
        for _, ans_start, ans_end in ans_match:
            ans_mentions.add((ans_start, ans_end))

    for match_id, start, end in matches:
        if ans is not None:
            if (start, end) in ans_mentions:
                continue

        span = doc[start:end].text  # the matched span
        # a word that appears in the answer is not considered as a mention in the question
        # if len(set(span.split(" ")).intersection(set(ans.split(" ")))) > 0:
        #     continue
        original_concept = nlp.vocab.strings[match_id]  # recover the concept string from the match id
        original_concept_set = set()
        original_concept_set.add(original_concept)

        # print("span", span)
        # print("concept", original_concept)
        # print("Matched '" + span + "' to the rule '" + string_id)

        # why do you lemmatize a mention whose len == 1?
        if len(original_concept.split("_")) == 1:
            # tag = doc[start].tag_
            # if tag in ['VBN', 'VBG']:
            original_concept_set.update(lemmatize(nlp, nlp.vocab.strings[match_id]))  # also add the lemmatized form

        if span not in span_to_concepts:
            span_to_concepts[span] = set()
        # map each matched span to its candidate concepts, e.g.
        # {'accelerator': {'accelerators', 'accelerator'}, 'controller': {'controller'}}
        span_to_concepts[span].update(original_concept_set)

    for span, concepts in span_to_concepts.items():
        concepts_sorted = list(concepts)
        # print("span:", span)
        # print("concepts_sorted:", concepts_sorted)
        concepts_sorted.sort(key=len)  # sort by string length

        # mentioned_concepts.update(concepts_sorted[0:2])
        shortest = concepts_sorted[0:3]  # keep the three shortest candidates

        for c in shortest:
            if c in blacklist:  # skip blacklisted words
                continue
            # a set with one string like: set("like_apples")
            lcs = lemmatize(nlp, c)  # e.g. the lemma of 'accelerators' is 'accelerator'
            intersect = lcs.intersection(shortest)  # prefer the lemma if it is itself among the shortest candidates
            if len(intersect) > 0:
                mentioned_concepts.add(list(intersect)[0])
            else:
                mentioned_concepts.add(c)

        # if a mention exactly matches with a concept
        exact_match = set([concept for concept in concepts_sorted
                           if concept.replace("_", " ").lower() == span.lower()])
        # print("exact match:", exact_match)
        assert len(exact_match) < 2  # there can be at most one exact match
        mentioned_concepts.update(exact_match)

    return mentioned_concepts
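The lemmatize helper is also not shown in this excerpt. From how it is called (it takes nlp and a '_'-joined concept and returns a set of lemmatized forms to intersect with concept strings), a plausible sketch is:

def lemmatize(nlp, concept):
    # Replace underscores with spaces, lemmatize each token, and rejoin with
    # underscores, e.g. 'accelerators' -> {'accelerator'}.
    # Sketch inferred from usage; the repo's version may differ.
    doc = nlp(concept.replace("_", " "))
    return {"_".join([token.lemma_ for token in doc])}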
With the concepts extracted, the next step is pruning:
remove concepts that contain stopwords, and keep only those present in the ConceptNet concept vocabulary.
def prune(data, cpnet_vocab_path):
    # reload cpnet_vocab
    with open(cpnet_vocab_path, "r", encoding="utf8") as fin:
        cpnet_vocab = [l.strip() for l in fin]

    prune_data = []
    for item in tqdm(data):
        qc = item["qc"]  # question concepts, e.g. ['blow', 'change', 'effort']
        prune_qc = []
        for c in qc:
            if c[-2:] == "er" and c[:-2] in qc:  # drop 'Xer' if its stem 'X' is already in qc
                continue
            if c[-1:] == "e" and c[:-1] in qc:  # drop 'Xe' if its stem 'X' is already in qc
                continue
            have_stop = False
            # remove all concepts having stopwords, including hard-grounded ones
            for t in c.split("_"):
                if t in nltk_stopwords:
                    have_stop = True
            # keep c only if it contains no stopword token and appears in the concept vocab
            if not have_stop and c in cpnet_vocab:
                prune_qc.append(c)

        ac = item["ac"]
        prune_ac = []
        for c in ac:
            if c[-2:] == "er" and c[:-2] in ac:
                continue
            if c[-1:] == "e" and c[:-1] in ac:
                continue
            # answer concepts are filtered less aggressively: drop c only if
            # every one of its tokens is a stopword
            all_stop = True
            for t in c.split("_"):
                if t not in nltk_stopwords:
                    all_stop = False
            if not all_stop and c in cpnet_vocab:
                prune_ac.append(c)

        try:
            assert len(prune_ac) > 0 and len(prune_qc) > 0
        except Exception:
            pass
            # print("In pruning")
            # print(prune_qc)
            # print(prune_ac)
            # print("original:", qc, ac)

        item["qc"] = prune_qc
        item["ac"] = prune_ac
        prune_data.append(item)

    return prune_data
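Finally, a sketch of how ground would be invoked for the CSQA training split. The paths below are placeholders following the repo's typical data layout and should be adjusted to the actual setup:

# Example invocation; all paths are placeholders, adjust to your layout.
if __name__ == '__main__':
    ground(
        statement_path='data/csqa/statement/train.statement.jsonl',
        cpnet_vocab_path='data/cpnet/concept.txt',
        pattern_path='data/cpnet/matcher_patterns.json',
        output_path='data/csqa/grounded/train.grounded.jsonl',
        num_processes=8,
    )

Each output line is one of the dicts returned by ground_qa_pair, i.e. {"sent": ..., "ans": ..., "qc": [...], "ac": [...]}, with the question and answer concept lists already pruned.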