记忆网络之Key-Value Memory Networks tensorflow实现

记忆网络之Key-Value Memory Networks tensorflow实现

前面我们介绍了Key-Value Memory Networks这篇论文,这里我们介绍一下该论文使用tensorflow的实现方法。其实github上面有一个实现方案,但是该方案用于仿真bAbI任务的数据集,与QA任务还有一定的区别,又与之前一篇End-to-End MemNN已经对该数据进行了仿真实现,所以这篇文章更想尝试一下QA任务的实现方案,这也是本文的主要关注点所在。其实二者的主要区别在于数据集的处理和Key-Value memory的表示,所以我们将结合https://github.com/siyuanzhao/key-value-memory-networkshttps://github.com/dapurv5/neural_kbqa(主要借鉴其数据处理部分)两个代码进行介绍。最终整合之后的完整代码可以到我的github上面进行下载~~接下来我们从数据处理、模型实现、模型训练几个部分进行介绍。

数据处理

这里使用论文中提出的MovieQA数据集,可以到Facebook官方github上面进行下载,此外,其实官网也给出了torch的实现方案,有兴趣的同学也可以进行参考学习。数据下载地址如下所示:
http://www.thespermwhale.com/jaseweston/babi/movieqa.tar.gz

解压之后发现主要包括knowledge_source和questions两个文件夹,分别保存了知识库和QA对。这里为了方便起见,我们使用KB作为知识源,而不使用wiki文章,因为wiki文章处理起来过于复杂,可以看看官网的代码,而相比之下KB结构相对比较简单,做key hashing等操作借助图的数据结构还可以方便的实现。这里也主要对该操作进行介绍,其他的都可以参考上面那个连接进行了解,当然也可以直接使用我github上面分享出来的处理完数据集,将精力放在模型构建上面。下面我们看一下如何将KB知识库构建成一个知识图谱的形式(这里借助networkx第三方库):


    HIGH_DEGREE_THRESHOLD = 50
    class KnowledgeGraph(object):
        def __init__(self, graph_path, unidirectional=True):
            """
            初始化知识图谱,如果unidirectional==False,也就是使用反向三元组,将(e2,invR,e1)也添加到图中
            """
            # 初始化一个图结构
            self.G = nx.DiGraph()
            # 读取知识库中的每个三元组,并将其添加到图中
            with open(graph_path, 'r') as graph_file:
                for line in graph_file:
                    line = clean_line(line)
                    e1, relation, e2 = line.split(PIPE)
                    self.G.add_edge(e1, e2, {
  "relation": relation})
                    # 将反向三元组添加到图中
                    if not unidirectional:
                        self.G.add_edge(e2, e1, {
  "relation": self.get_inverse_relation(relation)})

            # 记录图中所有的高度节点,如果节点的入度大于HIGH_DEGREE_THRESHOLD,则为高度节点
            self.high_degree_nodes = set([])
            indeg = self.G.in_degree()
            for v in indeg:
                if indeg[v] > HIGH_DEGREE_THRESHOLD:
                    self.high_degree_nodes.add(v)
            self.all_entities = set(nx.nodes(self.G))

        def get_inverse_relation(self, relation):
            return "INV_" + relation

        def get_all_paths(self, source, target, cutoff):
            '''
            得到source和target之间的所有的路径。
            :param source: 起始节点
            :param target: 终止节点
            :param cutoff: 是否启用cutoff
            :return: 返回两个列表,一个是使用节点表示的path,另一个是使用边表示的path。如下所示:
            [ [e1, e2], [e1, e3, e2]], [[r1], [r2, r3] ]
            '''
            if source == target:
                return [], []
            paths_of_entities = []
            paths_of_relations = []
            #遍历源和目的之间所有的path
            for path in nx.all_simple_paths(self.G, source, target, cutoff):
                #可以将path直接添加到paths_of_entities中
                paths_of_entities.append(path)
                relations_path = []
                #对上面每个path取出relation,并组成新的path,然后添加到paths_of_relations中
                for i in range(0, len(path) - 1):
                    relation = self.G[path[i]][path[i + 1]]['relation']
                    relations_path.append(relation)
                paths_of_relations.append(relations_path)
            return paths_of_entities, paths_of_relations

        def get_candidate_neighbors(self, node, num_hops=2, avoid_high_degree_nodes=True):
            '''
            得到与一个节点相关的周边跳数在num_hops之内的所有节点
            :param node:要寻找的节点
            :param num_hops:跳数
            :param avoid_high_degree_nodes:是否去除高度节点
            :return:返回得到的所有节点
            '''
            result = set([])
            q = [node]
            visited = set([node])
            dist = {node: 0}
            while len(q) > 0:
                u = q.pop(0)
                result.add(u)
                for nbr in self.G.neighbors(u):
                    if nbr in self.high_degree_nodes and avoid_high_degree_nodes:
                        continue
                    if nbr not in visited:
                        visited.add(nbr)
                        dist[nbr] = dist[u] + 1
                        if dist[nbr] <= num_hops:
                            q.append(nbr)
            result.remove(node)
  • 6
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值