记忆网络之Key-Value Memory Networks tensorflow实现
前面我们介绍了Key-Value Memory Networks这篇论文,这里我们介绍一下该论文使用tensorflow的实现方法。其实github上面有一个实现方案,但是该方案用于仿真bAbI任务的数据集,与QA任务还有一定的区别,又与之前一篇End-to-End MemNN已经对该数据进行了仿真实现,所以这篇文章更想尝试一下QA任务的实现方案,这也是本文的主要关注点所在。其实二者的主要区别在于数据集的处理和Key-Value memory的表示,所以我们将结合https://github.com/siyuanzhao/key-value-memory-networks和https://github.com/dapurv5/neural_kbqa(主要借鉴其数据处理部分)两个代码进行介绍。最终整合之后的完整代码可以到我的github上面进行下载~~接下来我们从数据处理、模型实现、模型训练几个部分进行介绍。
数据处理
这里使用论文中提出的MovieQA数据集,可以到Facebook官方github上面进行下载,此外,其实官网也给出了torch的实现方案,有兴趣的同学也可以进行参考学习。数据下载地址如下所示:
http://www.thespermwhale.com/jaseweston/babi/movieqa.tar.gz
解压之后发现主要包括knowledge_source和questions两个文件夹,分别保存了知识库和QA对。这里为了方便起见,我们使用KB作为知识源,而不使用wiki文章,因为wiki文章处理起来过于复杂,可以看看官网的代码,而相比之下KB结构相对比较简单,做key hashing等操作借助图的数据结构还可以方便的实现。这里也主要对该操作进行介绍,其他的都可以参考上面那个连接进行了解,当然也可以直接使用我github上面分享出来的处理完数据集,将精力放在模型构建上面。下面我们看一下如何将KB知识库构建成一个知识图谱的形式(这里借助networkx第三方库):
HIGH_DEGREE_THRESHOLD = 50
class KnowledgeGraph(object):
def __init__(self, graph_path, unidirectional=True):
"""
初始化知识图谱,如果unidirectional==False,也就是使用反向三元组,将(e2,invR,e1)也添加到图中
"""
# 初始化一个图结构
self.G = nx.DiGraph()
# 读取知识库中的每个三元组,并将其添加到图中
with open(graph_path, 'r') as graph_file:
for line in graph_file:
line = clean_line(line)
e1, relation, e2 = line.split(PIPE)
self.G.add_edge(e1, e2, {
"relation": relation})
# 将反向三元组添加到图中
if not unidirectional:
self.G.add_edge(e2, e1, {
"relation": self.get_inverse_relation(relation)})
# 记录图中所有的高度节点,如果节点的入度大于HIGH_DEGREE_THRESHOLD,则为高度节点
self.high_degree_nodes = set([])
indeg = self.G.in_degree()
for v in indeg:
if indeg[v] > HIGH_DEGREE_THRESHOLD:
self.high_degree_nodes.add(v)
self.all_entities = set(nx.nodes(self.G))
def get_inverse_relation(self, relation):
return "INV_" + relation
def get_all_paths(self, source, target, cutoff):
'''
得到source和target之间的所有的路径。
:param source: 起始节点
:param target: 终止节点
:param cutoff: 是否启用cutoff
:return: 返回两个列表,一个是使用节点表示的path,另一个是使用边表示的path。如下所示:
[ [e1, e2], [e1, e3, e2]], [[r1], [r2, r3] ]
'''
if source == target:
return [], []
paths_of_entities = []
paths_of_relations = []
#遍历源和目的之间所有的path
for path in nx.all_simple_paths(self.G, source, target, cutoff):
#可以将path直接添加到paths_of_entities中
paths_of_entities.append(path)
relations_path = []
#对上面每个path取出relation,并组成新的path,然后添加到paths_of_relations中
for i in range(0, len(path) - 1):
relation = self.G[path[i]][path[i + 1]]['relation']
relations_path.append(relation)
paths_of_relations.append(relations_path)
return paths_of_entities, paths_of_relations
def get_candidate_neighbors(self, node, num_hops=2, avoid_high_degree_nodes=True):
'''
得到与一个节点相关的周边跳数在num_hops之内的所有节点
:param node:要寻找的节点
:param num_hops:跳数
:param avoid_high_degree_nodes:是否去除高度节点
:return:返回得到的所有节点
'''
result = set([])
q = [node]
visited = set([node])
dist = {node: 0}
while len(q) > 0:
u = q.pop(0)
result.add(u)
for nbr in self.G.neighbors(u):
if nbr in self.high_degree_nodes and avoid_high_degree_nodes:
continue
if nbr not in visited:
visited.add(nbr)
dist[nbr] = dist[u] + 1
if dist[nbr] <= num_hops:
q.append(nbr)
result.remove(node)