End-To-End Memory Networks代码研读

本文深入解析了End-To-End Memory Networks的代码实现,主要涉及任务、数据处理和模型构建。该模型应用于问答场景,使用bAbi数据集进行训练。数据处理包括读取和解析故事,生成训练样本。模型构建中,详细介绍了记忆单元的嵌入、权重计算及注意力机制。总结指出,Memory Network的内存并非传统意义的外部存储,而是通过样本更新并依赖注意力机制读取。
摘要由CSDN通过智能技术生成

本文学习的代码来自于知乎文章:https://zhuanlan.zhihu.com/p/29679742
代码地址为:https://link.zhihu.com/?target=https%3A//github.com/domluna/memn2n
数据地址为:http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz
论文地址为:https://arxiv.org/abs/1503.08895

1 任务

解决QA场景。所使用的数据为bAbi。
https://zhuanlan.zhihu.com/p/29679742
将这些句子和Question作为模型输入进行建模,希望模型可以学习出这种推理模式。
以task1为例:
在这里插入图片描述
两句话后面跟一个问句,并且给出相应答案。答案后面的数字意味着该问题与哪一行相关。然后15行组成一个故事,也就是说这15行内的数据都是相关的,后面的15个组成另外一组数据。所以memory_size是10(15行中有10行是数据,5行是问题)。另外每个句子的组大长度是7。所以处理完之后的数据应该时15*7的矩阵。而且每15行数据会被处理成5组训练样本。第一组是前两行数据加问题和答案,第二个是前四行数据家问题和答案,这样继续下去。也就是说后面的问题是依据前面所有的数据回答的

2 数据处理

下面是加载任务数据的主函数

def load_task(data_dir, task_id, only_supporting=False):
    '''Load the nth task. There are 20 tasks in total.

    Returns a tuple containing the training and testing data for the task.
    '''
    assert task_id > 0 and task_id < 21

    files = os.listdir(data_dir)
    files = [os.path.join(data_dir, f) for f in files]
    s = 'qa{}_'.format(task_id)
    #获取对应task的训练数据和测试数据
    train_file = [f for f in files if s in f and 'train' in f][0]
    #对于任务一获得的是qa1_single-supporting-fact_test.txt文件名
    test_file = [f for f in files if s in f and 'test' in f][0]
    train_data = get_stories(train_file, only_supporting)
    test_data = get_stories(test_file, only_supporting)
    return train_data, test_data

下面是获取故事,给定文件名读取文件,获得故事。

 def get_stories(f, only_supporting=False):
       '''Given a file name, read the file, retrieve the stories, and then convert the sentences into a single story.
       If max_length is supplied, any stories longer than max_length tokens will be discarded.
       '''
       with open(f) as f:
           return parse_stories(f.readlines(), only_supporting=only_supporting)

下面具体如何解析得到每个故事

def parse_stories(lines, only_supporting=False):
       '''Parse stories provided in the bAbI tasks format
       If only_supporting is true, only the sentences that support the answer are kept.
       '''
       data = []
       story = []
       '''

		'''
       for line in lines:
       	   #将每一行转换成小写,分割成id和句子
       	   #如:2 Mary journeyed to the bathroom.拆分成2和Mary journeyed to the bathroom.
           line = str.lower(line)
           nid, line = line.split(' ', 1)
           nid = int(nid)
           #nid是每一行的序号(1~15之间),如果等于1则说明是一个新故事的开始,需要将story数组重置。
           if nid == 1:
               story = []
           '''
			2 Mary journeyed to the bathroom.
			3 Where is John? 	hallway	1
		   '''
           if '\t' in line: # 如果有\t,则说明是问题行
               q, a, supporting = line.split('\t')#supporting 表示通过哪一句推断出来的
               q = tokenize(q)
               #a = tokenize(a)
               
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值