Scalable Multi-Hop Relational Reasoning for Knowledge-Aware Question Answering
将外部知识融入模型进行推理学习,在CommonsenseQA数据集中取得SOTA的效果。
采用的外部知识为 ConceptNet。
代码分为5步:
- 下载相关数据集
- 对数据集进行预处理
- 超参数搜索(可选)
- 训练
- 评估
本部分主要讲解第 2 步——数据集预处理。
将词向量文本转为cache文件保存
def load_vectors(path, skip_head=False, add_special_tokens=None, random_state=0):
    """
    Read a space-separated word-vector text file (e.g. GloVe) into memory.

    Each line is parsed as ``word v1 v2 ... vd``; the word is lower-cased.
    Optionally appends randomly initialised vectors for extra special tokens.

    :param path: path of the UTF-8 text embedding file.
    :param skip_head: if True, the first line is a header and is not parsed.
    :param add_special_tokens: optional iterable of extra tokens appended to the
        vocabulary; each gets a vector sampled from a normal distribution whose
        mean / standard deviation are those of all loaded values.
    :param random_state: seed passed to ``np.random.seed`` before sampling.
    :return: ``(vocab, vectors)`` where ``vocab[i]`` labels row ``i`` of the
        float64 matrix ``vectors``.
    :raises ValueError: if the file contains no vector lines.
    """
    # Count lines first so the matrix can be pre-allocated (and tqdm can show a total).
    with open(path, 'r', encoding='utf-8') as fin:
        nrow = sum(1 for _ in fin)
    if skip_head:
        # The header is not a vector row; counting it would leave a stray all-zero
        # row at the end and make len(vocab) != len(vectors).
        nrow -= 1

    vocab = []
    vectors = None
    with open(path, "r", encoding="utf8") as fin:
        if skip_head:
            fin.readline()
        for i, line in tqdm(enumerate(fin), total=nrow):
            elements = line.strip().split(" ")
            vocab.append(elements[0].lower())
            vec = np.array(elements[1:], dtype=float)
            if vectors is None:
                # Embedding dimension is only known once the first row has been parsed.
                vectors = np.zeros((nrow, len(vec)), dtype=np.float64)
            vectors[i] = vec

    if vectors is None:
        raise ValueError(f'no embedding vectors found in {path}')

    # None means "no special tokens"; previously this crashed on `vocab += None`.
    special_tokens = [] if add_special_tokens is None else list(add_special_tokens)
    np.random.seed(random_state)
    # np.random.normal(loc, scale): loc = mean, scale = standard deviation.
    add_vectors = np.random.normal(np.mean(vectors), np.std(vectors),
                                   size=(len(special_tokens), vectors.shape[1]))
    vectors = np.concatenate((vectors, add_vectors), 0)
    vocab += special_tokens
    return vocab, vectors
def glove2npy(glove_path, output_npy_path, output_vocab_path, skip_head=False,
              add_special_tokens=EXTRA_TOKS, random_state=0):
    """
    Convert a GloVe text embedding file into a binary ``.npy`` matrix plus a vocab file.

    :param glove_path: GloVe text file (one ``word v1 ... vd`` per line).
    :param output_npy_path: destination of the embedding matrix (``np.save``).
    :param output_vocab_path: destination of the vocabulary, one word per line,
        in the same order as the matrix rows.
    :param skip_head: forwarded to ``load_vectors`` (skip a header line).
    :param add_special_tokens: forwarded to ``load_vectors``.
    :param random_state: forwarded to ``load_vectors``.
    """
    print('binarizing GloVe embeddings...')
    words, matrix = load_vectors(
        glove_path,
        skip_head=skip_head,
        add_special_tokens=add_special_tokens,
        random_state=random_state,
    )

    np.save(output_npy_path, matrix)
    with open(output_vocab_path, "w", encoding='utf-8') as vocab_file:
        vocab_file.writelines(w + '\n' for w in words)

    print(f'Binarized GloVe embeddings saved to {output_npy_path}')
    print(f'GloVe vocab saved to {output_vocab_path}')
    print()
glove.6B.300d.txt 的每一行格式如下:
单词 + 空格 + 以空格分隔的 300 维向量
代码中 `np.random.normal` 从正态(高斯)分布中抽取随机样本,作为特殊 token 的初始向量。
其均值为 `np.mean(vectors)`,标准差(不是方差)为 `np.std(vectors)`,得到的矩阵维度为 `(n_special, vectors.shape[1])`。
>>> import numpy as np
>>> mu, sigma = 0, 0.1
>>> s = np.random.normal(mu, sigma, size=(1,2))
>>> s
array([[-0.1668889 , 0.08364912]])
处理 ConceptNet csv 文件:抽取英文关系,并基于英文词语生成图实体 vocab 文件
def extract_english(conceptnet_path, output_csv_path, output_vocab_path):
    """
    Reads original conceptnet csv file and extracts all English relations (head and tail are both English entities) into
    a new file, with the following format for each line: <relation>\t<head>\t<tail>\t<weight>.

    Every distinct concept appearing in a kept relation is also written to
    ``output_vocab_path``, one per line, in first-seen order.

    :param conceptnet_path: raw ConceptNet assertions file (tab-separated columns).
    :param output_csv_path: destination of the filtered relation file.
    :param output_vocab_path: destination of the concept vocabulary file.
    :return: None
    """
    print('extracting English concepts and relations from ConceptNet...')
    # Maps raw relation names to merged ones; a leading '*' means the direction is reversed.
    relation_mapping = load_merge_relation()

    # Count lines only so tqdm can show a total; the `with` closes the handle promptly.
    with open(conceptnet_path, 'r', encoding='utf-8') as fin:
        num_lines = sum(1 for _ in fin)

    cpnet_vocab = []
    concepts_seen = set()
    with open(conceptnet_path, 'r', encoding="utf8") as fin, \
            open(output_csv_path, 'w', encoding="utf8") as fout:
        for line in tqdm(fin, total=num_lines):
            # Example line:
            # '/a/[/r/Antonym/,/c/ab/агыруа/n/,/c/ab/аҧсуа/]\t/r/Antonym\t/c/ab/агыруа/n\t/c/ab/аҧсуа\t{"dataset":
            # "/d/wiktionary/en", "license": "cc:by-sa/4.0", "sources": [{"contributor": "/s/resource/wiktionary/en",
            # "process": "/s/process/wikiparsec/1"}], "weight": 1.0}\n'
            toks = line.strip().split('\t')
            # Keep only relations whose head and tail are both English concepts.
            if not (toks[2].startswith('/c/en/') and toks[3].startswith('/c/en/')):
                continue

            # Preprocessing: drop the part-of-speech suffix (del_pos), keep only the
            # last path segment (trims "/c/en/"), and lower-case for uniformity.
            rel = toks[1].split("/")[-1].lower()  # '/r/Antonym' -> 'antonym'
            head = del_pos(toks[2]).split("/")[-1].lower()
            tail = del_pos(toks[3]).split("/")[-1].lower()

            # Keep only concepts made of letters (allowing '_' and '-' as separators).
            if not head.replace("_", "").replace("-", "").isalpha():
                continue
            if not tail.replace("_", "").replace("-", "").isalpha():
                continue
            # Relations absent from the merge table are discarded.
            if rel not in relation_mapping:
                continue

            rel = relation_mapping[rel]
            if rel.startswith("*"):  # reversed relation: swap head and tail
                head, tail, rel = tail, head, rel[1:]

            # Column 5 is JSON metadata, e.g. {"dataset": ..., "weight": 1.0}
            data = json.loads(toks[4])
            fout.write('\t'.join([rel, head, tail, str(data["weight"])]) + '\n')

            # Record each concept once, preserving first-seen order.
            for w in (head, tail):
                if w not in concepts_seen:
                    concepts_seen.add(w)
                    cpnet_vocab.append(w)

    with open(output_vocab_path, 'w', encoding='utf-8') as fout:
        for word in cpnet_vocab:
            fout.write(word + '\n')

    print(f'extracted ConceptNet csv file saved to {output_csv_path}')
    print(f'extracted concept vocabulary saved to {output_vocab_path}')
    print()
其中 `load_merge_relation()` 用于合并同类关系;映射结果以 `*` 开头的关系表示方向相反,抽取时会交换 head 和 tail 并去掉 `*`。
>>> relation_groups = [
... 'atlocation/locatednear',
... 'capableof',
... 'causes/causesdesire/*motivatedbygoal',
... 'createdby',
... 'desires',
... 'antonym/distinctfrom',
... 'hascontext',
... ]
>>> relation_mapping = load_merge_relation()
>>> relation_mapping
{'atlocation': 'atlocation', 'locatednear': 'atlocation', 'capableof': 'capableof', 'causes': 'causes', 'causesdesire': 'causes', 'motivatedbygoal': '*causes', 'createdby': 'createdby', 'desires': 'desires', 'antonym': 'antonym', 'distinctfrom': 'antonym', 'hascontext': 'hascontext', 'hasproperty': 'hasproperty', 'hassubevent': 'hassubevent', 'hasfirstsubevent': 'hassubevent', 'haslastsubevent': 'hassubevent', 'hasprerequisite': 'hassubevent', 'entails': 'hassubevent', 'mannerof': 'hassubevent', 'isa': 'isa', 'instanceof': 'isa', 'definedas': 'isa', 'madeof': 'madeof', 'notcapableof': 'notcapableof', 'notdesires': 'notdesires', 'partof': 'partof', 'hasa': '*partof', 'relatedto': 'relatedto', 'similarto': 'relatedto', 'synonym': 'relatedto', 'usedfor': 'usedfor', 'receivesaction': 'receivesaction'}
如果执行 `nltk.download('stopwords')` 时无法下载 stopwords,可以手动从 GitHub 上下载:
https://github.com/nltk/nltk_data/blob/gh-pages/packages/corpora/stopwords.zip
将文件解压并放在报错提示的路径下,即可解决报错。
基于词向量获取图实体的embedding
def load_vectors_from_npy_with_vocab(glove_npy_path, glove_vocab_path, vocab, verbose=True, save_path=None):
    """
    Look up a pretrained embedding for every word in ``vocab`` (e.g. graph concepts).

    :param glove_npy_path: ``.npy`` matrix whose row i is the vector of line i of ``glove_vocab_path``.
    :param glove_vocab_path: word list, one word per line, aligned with the matrix rows.
    :param vocab: words to look up; output row i belongs to ``vocab[i]``.
    :param verbose: print the vocab size and the out-of-vocabulary (OOV) rate.
    :param save_path: if given, the matrix is saved there with ``np.save`` and None is returned.
    :return: float matrix of shape ``(len(vocab), embedding_dim)`` when ``save_path`` is None.
        Words missing from the pretrained vocab keep an all-zero row.
    """
    # Map each pretrained word to its row index in the embedding matrix.
    with open(glove_vocab_path, 'r', encoding='utf-8') as fin:
        glove_w2idx = {line.strip(): i for i, line in enumerate(fin)}
    glove_emb = np.load(glove_npy_path)

    # Same embedding width as the pretrained matrix; OOV rows stay zero.
    vectors = np.zeros((len(vocab), glove_emb.shape[1]), dtype=float)
    oov_cnt = 0
    for i, word in enumerate(vocab):
        idx = glove_w2idx.get(word)
        if idx is None:
            oov_cnt += 1
        else:
            vectors[i] = glove_emb[idx]

    if verbose:
        print(len(vocab))
        # max(..., 1) avoids ZeroDivisionError for an empty vocab (the rate is then 0).
        print('embedding oov rate: {:.4f}'.format(oov_cnt / max(len(vocab), 1)))

    if save_path is None:
        return vectors
    np.save(save_path, vectors)
def load_pretrained_embeddings(glove_npy_path, glove_vocab_path, vocab_path, verbose=True, save_path=None):
    """
    Build embeddings for the graph-concept vocabulary from pretrained word vectors.

    :param glove_npy_path: pretrained embedding matrix (``.npy``).
    :param glove_vocab_path: vocab file aligned with the rows of ``glove_npy_path``.
    :param vocab_path: graph-concept vocab file, one concept per line.
    :param verbose: forwarded to ``load_vectors_from_npy_with_vocab``.
    :param save_path: where to save the concept embedding matrix built from the word vectors.
    :return: the result of ``load_vectors_from_npy_with_vocab``. This is the embedding
        matrix when ``save_path`` is None, and None otherwise. Returning it means callers
        that do not save actually receive the computed vectors.
    """
    with open(vocab_path, 'r', encoding='utf-8') as fin:
        vocab = [line.strip() for line in fin]
    return load_vectors_from_npy_with_vocab(glove_npy_path=glove_npy_path, glove_vocab_path=glove_vocab_path,
                                            vocab=vocab, verbose=verbose, save_path=save_path)
基于关系csv文件构建图
def construct_graph(cpnet_csv_path, cpnet_vocab_path, output_path, prune=True):
    """
    Build a NetworkX MultiDiGraph from the extracted ConceptNet relation file and pickle it.

    :param cpnet_csv_path: tab-separated file with <rel> <head> <tail> <weight> per line
        (as written by extract_english).
    :param cpnet_vocab_path: concept vocabulary file; the line index is the node id.
    :param output_path: where the pickled graph is written.
    :param prune: if True, drop edges touching a blacklisted concept and all
        "hascontext" edges.

    Each kept relation is added twice. The forward edge subj->obj gets relation id
    ``rel``; the reverse edge obj->subj gets ``rel + len(relation2id)``, so reverse
    edges have their own relation ids. Self-loops and duplicate (subj, obj, rel)
    triples are skipped.
    """
    # nx.write_gpickle was removed in NetworkX 3.0; it was a thin wrapper around
    # pickle.dump(graph, fh, pickle.HIGHEST_PROTOCOL), which is used directly below.
    import pickle

    print('generating ConceptNet graph file...')
    nltk.download('stopwords')
    nltk_stopwords = nltk.corpus.stopwords.words('english')
    nltk_stopwords += ["like", "gone", "did", "going", "would", "could",
                       "get", "in", "up", "may", "wanter"]  # issue: mismatch with the stop words in grouding.py
    # NOTE(review): nltk_stopwords is currently unused (the stop-word filter in not_save is disabled).
    blacklist = {"uk", "us", "take", "make", "object", "person", "people"}  # issue: mismatch with the blacklist in grouding.py

    with open(cpnet_vocab_path, "r", encoding="utf8") as fin:
        id2concept = [w.strip() for w in fin]
    concept2id = {w: i for i, w in enumerate(id2concept)}
    id2relation = merged_relations
    relation2id = {r: i for i, r in enumerate(id2relation)}
    num_relations = len(relation2id)

    def not_save(cpt):
        """Return True if concept ``cpt`` should be pruned from the graph."""
        # Originally, phrases containing a stop word (e.g. "branch out") were also
        # dropped; that filter is disabled, so only the blacklist applies.
        return cpt in blacklist

    # Count lines only so tqdm can show a total; the `with` closes the handle promptly.
    with open(cpnet_csv_path, 'r', encoding='utf-8') as fin:
        nrow = sum(1 for _ in fin)

    graph = nx.MultiDiGraph()
    attrs = set()  # (subj, obj, rel) triples already added, used to skip duplicates
    with open(cpnet_csv_path, "r", encoding="utf8") as fin:
        for line in tqdm(fin, total=nrow):
            ls = line.strip().split('\t')
            rel = relation2id[ls[0]]
            subj = concept2id[ls[1]]  # head
            obj = concept2id[ls[2]]  # tail
            weight = float(ls[3])
            if prune and (not_save(ls[1]) or not_save(ls[2]) or id2relation[rel] == "hascontext"):
                continue
            if subj == obj:  # delete self-loops
                continue
            if (subj, obj, rel) not in attrs:
                graph.add_edge(subj, obj, rel=rel, weight=weight)
                attrs.add((subj, obj, rel))
                # Reverse edge with its own relation id: current rel id + total number of relations.
                graph.add_edge(obj, subj, rel=rel + num_relations, weight=weight)
                attrs.add((obj, subj, rel + num_relations))

    with open(output_path, 'wb') as fout:
        pickle.dump(graph, fout, pickle.HIGHEST_PROTOCOL)
    print(f"graph file saved to {output_path}")
    print()
构建 spaCy 匹配模式(matcher patterns)
def create_matcher_patterns(cpnet_vocab_path, output_path, debug=False):
    """
    Build a lemma-based pattern for every graph concept and dump them all as JSON.

    :param cpnet_vocab_path: graph-concept vocab file, read via ``load_cpnet_vocab``.
    :param output_path: destination of the JSON dict mapping
        ``concept_with_underscores -> pattern``.
    :param debug: forwarded to ``create_pattern``. When True, the text of every concept
        whose result has a falsy first element is written to ``filtered_concept.txt``.
    """
    cpnet_vocab = load_cpnet_vocab(cpnet_vocab_path)  # e.g. ['ab extract', 'ab intera']
    nlp = spacy.load('en_core_web_sm', disable=['parser', 'ner', 'textcat'])
    docs = nlp.pipe(cpnet_vocab)
    all_patterns = {}

    # utf8: concepts may contain non-ASCII letters, which the platform default encoding may reject.
    filtered_out = open("filtered_concept.txt", "w", encoding="utf8") if debug else None
    try:
        for doc in tqdm(docs, total=len(cpnet_vocab)):
            pattern = create_pattern(nlp, doc, debug)
            # Check for None before indexing; the old order crashed on pattern[0] in debug mode.
            if pattern is None:
                continue
            if filtered_out is not None and not pattern[0]:
                filtered_out.write(pattern[1] + '\n')
            # Key uses '_' instead of spaces, e.g. 'ab_extra': [{'LEMMA': 'ab'}, {'LEMMA': 'extra'}]
            all_patterns["_".join(doc.text.split(" "))] = pattern

        print("Created " + str(len(all_patterns)) + " patterns.")
        with open(output_path, "w", encoding="utf8") as fout:
            json.dump(all_patterns, fout)
    finally:
        # Close the debug file even if pattern creation or JSON dumping fails.
        if filtered_out is not None:
            filtered_out.close()