Scalable Multi-Hop Relational Reasoning for Knowledge-Aware Question Answering
将外部知识融入模型进行推理学习,在CommonsenseQA数据集中取得SOTA的效果。
采用的外部知识为conceptNet
代码分为5步:
- 下载相关数据集
- 对数据集进行预处理
- 超参数搜索(可选)
- 训练
- 评估
本部分主要讲解第2部分----数据集预处理 。
上一部分解析了 csqa 训练数据的相关处理流程;
由于解析 csqa 的内容较为繁多,希望大家有耐心。
根据head tail在concept图中寻找paths
# The 17 merged ConceptNet relation types used throughout this pipeline.
# A relation's position in this list is its relation id (0 = 'antonym',
# 15 = 'relatedto', ...); ids >= 17 encode the reverse direction of id - 17.
merged_relations = [ 'antonym', 'atlocation', 'capableof', 'causes', 'createdby', 'isa', 'desires',
'hassubevent', 'partof','hascontext','hasproperty','madeof', 'notcapableof',
'notdesires', 'receivesaction', 'relatedto','usedfor',]
def load_resources(cpnet_vocab_path):
    """Load the ConceptNet concept vocabulary and build id <-> name mappings.

    Populates the module-level globals: id2concept / concept2id for concepts
    (one concept per line of the vocab file), and id2relation / relation2id
    from the fixed merged_relations list.
    """
    global concept2id, id2concept, relation2id, id2relation
    with open(cpnet_vocab_path, "r", encoding="utf8") as fin:
        id2concept = [line.strip() for line in fin]
    concept2id = {concept: idx for idx, concept in enumerate(id2concept)}
    id2relation = merged_relations
    relation2id = {rel: idx for idx, rel in enumerate(id2relation)}
加载图的时候,将图中带权重的节点进行遍历
def load_cpnet(cpnet_graph_path):
    """Load the pickled ConceptNet multigraph and derive a simple graph.

    Populates the globals `cpnet` (the full multigraph) and `cpnet_simple`
    (an undirected nx.Graph where parallel edges between the same node pair
    are collapsed into one edge whose weight is the sum of their weights).
    """
    global cpnet, cpnet_simple
    cpnet = nx.read_gpickle(cpnet_graph_path)
    cpnet_simple = nx.Graph()
    for u, v, data in cpnet.edges(data=True):
        w = data.get('weight', 1.0)  # edges without an explicit weight count as 1.0
        if cpnet_simple.has_edge(u, v):
            # parallel edge: accumulate its weight onto the existing edge
            cpnet_simple[u][v]['weight'] += w
        else:
            cpnet_simple.add_edge(u, v, weight=w)
主流程 寻找图中paths
def find_paths(grounded_path, cpnet_vocab_path, cpnet_graph_path, output_path, num_processes=1, random_state=0):
    """Search ConceptNet for paths linking each example's question/answer concepts.

    grounded_path: grounded training examples (one JSON object per line)
    cpnet_vocab_path: ConceptNet vocabulary file
    cpnet_graph_path: pruned ConceptNet graph (blocklist words removed)
    output_path: destination file; one JSON line of found paths per example
    num_processes: size of the worker pool
    random_state: seed for random / numpy RNGs
    """
    print(f'generating paths for {grounded_path}...')
    random.seed(random_state)
    np.random.seed(random_state)

    global concept2id, id2concept, relation2id, id2relation, cpnet_simple, cpnet
    # Lazily initialise the shared vocabulary mappings and the graph.
    if any(x is None for x in [concept2id, id2concept, relation2id, id2relation]):
        load_resources(cpnet_vocab_path)
    if cpnet is None or cpnet_simple is None:
        load_cpnet(cpnet_graph_path)

    with open(grounded_path, 'r', encoding='utf-8') as fin:
        examples = [json.loads(line) for line in fin]
    # Keep only the answer-concept ("ac") and question-concept ("qc") token lists.
    examples = [[record["ac"], record["qc"]] for record in examples]

    with Pool(num_processes) as pool, open(output_path, 'w', encoding='utf-8') as fout:
        for pfr_qa in tqdm(pool.imap(find_paths_qa_pair, examples), total=len(examples)):
            fout.write(json.dumps(pfr_qa) + '\n')

    print(f'paths saved to {output_path}')
    print()
寻找问题答案对在conceptnet中的路径
def find_paths_qa_pair(qa_pair):
    """Find ConceptNet paths for every (question concept, answer concept) pair
    of one example.

    qa_pair: [answer_concepts, question_concepts], each a list of concept names.
    Returns a list of {"ac": ..., "qc": ..., "pf_res": ...} dicts covering the
    full cross product of answer and question concepts.
    """
    answer_concepts, question_concepts = qa_pair
    results = []
    for answer_concept in answer_concepts:
        for question_concept in question_concepts:
            paths = find_paths_qa_concept_pair(question_concept, answer_concept)
            results.append({"ac": answer_concept, "qc": question_concept, "pf_res": paths})
    return results
从图中寻找相关路径和关系并作筛选
def find_paths_qa_concept_pair(source: str, target: str, ifprint=False):
    """Find ConceptNet paths for a (question concept, answer concept) pair.

    source: question concept name (text)
    target: answer concept name (text)
    ifprint: if True, pretty-print each path with its relations

    Returns a list of {"path": [concept ids], "rel": [[relation ids], ...]}
    entries (one "rel" sub-list per hop), or None when either concept is
    missing from the pruned graph.
    """
    global cpnet, cpnet_simple, concept2id, id2concept, relation2id, id2relation
    s = concept2id[source]
    t = concept2id[target]
    if s not in cpnet_simple.nodes() or t not in cpnet_simple.nodes():
        return None  # concept not present in the graph

    # Collect up to the 100 shortest simple paths with at most 5 nodes each.
    # nx.shortest_simple_paths yields simple paths ordered from shortest to
    # longest, so we can stop as soon as a path exceeds the length limit.
    all_path = []
    try:
        for p in nx.shortest_simple_paths(cpnet_simple, source=s, target=t):
            if len(p) > 5 or len(all_path) >= 100:  # top 100 paths, <= 5 nodes each
                break
            if len(p) >= 2:  # skip trivial single-node paths
                all_path.append(p)
    except nx.exception.NetworkXNoPath:
        pass

    pf_res = []
    for p in all_path:
        rl = []  # one list of relation ids per hop of the path
        for src in range(len(p) - 1):
            src_concept = p[src]
            tgt_concept = p[src + 1]
            rel_list = get_edge(src_concept, tgt_concept)  # relations on this hop
            rl.append(rel_list)
            if ifprint:
                rel_list_str = []
                for rel in rel_list:  # turn relation ids into readable names
                    if rel < len(id2relation):
                        rel_list_str.append(id2relation[rel])
                    else:
                        # an id >= len(id2relation) encodes the reverse relation, marked "*"
                        rel_list_str.append(id2relation[rel - len(id2relation)] + "*")
                print(id2concept[src_concept], "----[%s]---> " % ("/".join(rel_list_str)), end="")
                if src + 1 == len(p) - 1:
                    print(id2concept[tgt_concept], end="")
        if ifprint:
            print()
        pf_res.append({"path": p, "rel": rl})
    return pf_res
对关系进行去重
def get_edge(src_concept, tgt_concept):
    """Return the unique relation ids on the edges between two adjacent concepts.

    `cpnet` is a multigraph, so cpnet[u][v] maps edge keys to attribute dicts;
    several parallel edges may carry the same 'rel' id.
    """
    global cpnet
    edge_data = cpnet[src_concept][tgt_concept]  # edge-key -> attribute dict
    # dict.fromkeys deduplicates while preserving first-seen order, replacing
    # the original side-effecting `seen.add(...) or True` list comprehension.
    return list(dict.fromkeys(attrs['rel'] for attrs in edge_data.values()))
对find的paths进行打分
打分函数主流程与 find_paths 的流程类似:先加载相关词向量和映射关系,然后对 paths 进行打分。打分基于 scipy 的 cosine 距离(换算为余弦相似度),对路径逐跳计算得分。
def score_paths(raw_paths_path, concept_emb_path, rel_emb_path, cpnet_vocab_path, output_path, num_processes=1, method='triple_cls'):
    """Score the paths produced by find_paths using TransE embeddings.

    raw_paths_path: paths found by find_paths (one JSON line per statement)
    concept_emb_path: TransE entity embeddings (.npy)
    rel_emb_path: TransE relation embeddings (.npy)
    cpnet_vocab_path: ConceptNet vocabulary file
    output_path: destination file; one JSON line of scores per statement
    num_processes: size of the worker pool
    method: only 'triple_cls' is implemented

    Raises NotImplementedError for any other scoring method.
    """
    print(f'scoring paths for {raw_paths_path}...')
    global concept2id, id2concept, relation2id, id2relation
    # Lazily initialise the shared vocabulary mappings and embeddings.
    if any(x is None for x in [concept2id, id2concept, relation2id, id2relation]):
        load_resources(cpnet_vocab_path)
    global concept_embs, relation_embs
    if concept_embs is None:
        concept_embs = np.load(concept_emb_path)  # TransE entity embeddings
    if relation_embs is None:
        relation_embs = np.load(rel_emb_path)  # TransE relation embeddings

    if method != 'triple_cls':
        raise NotImplementedError()

    with open(raw_paths_path, 'r', encoding='utf-8') as fin:
        data = [json.loads(line) for line in fin]

    with Pool(num_processes) as p, open(output_path, 'w', encoding='utf-8') as fout:
        for statement_scores in tqdm(p.imap(score_qa_pairs, data), total=len(data)):
            fout.write(json.dumps(statement_scores) + '\n')

    print(f'path scores saved to {output_path}')
    print()
scores_paths 遍历, 对每个路径进行打分
def score_qa_pairs(qa_pairs):
    """Score every found path for each concept pair of one statement.

    qa_pairs: list of dicts with a "pf_res" entry, e.g.
        [{'path': [13709, 61996, 2929], 'rel': [[32], [7]]}] or None.
    Returns one list of path scores per pair (None where no paths were found).
    """
    statement_scores = []
    for pair in qa_pairs:
        paths = pair["pf_res"]
        if paths is None:
            statement_scores.append(None)
            continue
        path_scores = []
        for path in paths:
            assert len(path["path"]) > 1
            path_scores.append(score_triples(concept_id=path["path"], relation_id=path["rel"]))
        statement_scores.append(path_scores)
    return statement_scores
获取concept中path embedding的表示和关系表示,对每个path中的实体两两打分
def score_triples(concept_id, relation_id, debug=False):
    """Score one path with TransE embeddings.

    concept_id: concept ids along the path, e.g. [13709, 61996, 2929]
    relation_id: per-hop lists of relation ids, e.g. [[32], [7]]
    debug: if True, also print the path in readable form with its score

    Returns the product of the per-hop triple scores (path likelihood).

    Relation ids 0..16 index the 17 merged relations; an id r >= 17 denotes
    the reverse of relation r - 17.  0 ('antonym') and 15 ('relatedto') are
    treated as symmetric, so each direction implies the other.
    """
    global relation_embs, concept_embs, id2relation, id2concept
    concept = concept_embs[concept_id]  # embeddings of the concepts on the path
    num_relations = 17  # == len(merged_relations)

    relation = []
    flag = []
    expanded_rel_ids = []  # per-hop relation ids with symmetric partners added
    for hop_rels in relation_id:
        # Work on a copy: the original code appended to the caller's lists.
        hop_rels = list(hop_rels)
        # 'antonym' (0) is symmetric: one direction implies its reverse (17).
        if 0 in hop_rels and 17 not in hop_rels:
            hop_rels.append(17)
        elif 17 in hop_rels and 0 not in hop_rels:
            hop_rels.append(0)
        # 'relatedto' (15) is symmetric as well (reverse id 32).
        if 15 in hop_rels and 32 not in hop_rels:
            hop_rels.append(32)
        elif 32 in hop_rels and 15 not in hop_rels:
            hop_rels.append(15)
        expanded_rel_ids.append(hop_rels)

        embs = []
        l_flag = []
        for rel in hop_rels:
            if rel >= num_relations:
                # reverse relation: use the forward embedding, flag to flip h/t
                embs.append(relation_embs[rel - num_relations])
                l_flag.append(1)
            else:
                embs.append(relation_embs[rel])
                l_flag.append(0)
        relation.append(embs)
        flag.append(l_flag)

    res = 1
    for i in range(concept.shape[0] - 1):
        h = concept[i]
        t = concept[i + 1]
        score = score_triple(h, t, relation[i], flag[i])
        res *= score  # accumulate the per-hop scores along the path

    if debug:
        print("Num of concepts:")
        print(len(concept_id))
        to_print = ""
        for i in range(concept.shape[0] - 1):
            to_print += id2concept[concept_id[i]] + "\t"
            for rel in expanded_rel_ids[i]:
                if rel >= num_relations:
                    # 'r-' means reverse
                    to_print += ("r-" + id2relation[rel - num_relations] + "/ ")
                else:
                    to_print += id2relation[rel] + "/ "
        to_print += id2concept[concept_id[-1]]
        print(to_print)
        print("Likelihood: " + str(res) + "\n")

    return res
计算方法为余弦相似度
def score_triple(h, t, r, flag):
    """Score a single hop: the best cosine agreement between any candidate
    relation embedding and the (tail - head) translation vector.

    h: head-entity embedding
    t: tail-entity embedding
    r: candidate relation embeddings for this hop
    flag: per-relation markers; truthy means the relation is a reverse
          (id >= 17), so head and tail are swapped before scoring
    """
    best = -10
    for rel_emb, is_reverse in zip(r, flag):
        head, tail = (t, h) if is_reverse else (h, t)
        # scipy returns cosine *distance* d = 1 - cos_sim,
        # so (2 - d) / 2 == (1 + cos_sim) / 2, mapping the score into [0, 1]
        similarity = (2 - spatial.distance.cosine(rel_emb, tail - head)) / 2
        best = max(best, similarity)
    return best