The code studied in this article comes from this Zhihu article: https://zhuanlan.zhihu.com/p/29679742
Code: https://link.zhihu.com/?target=https%3A//github.com/domluna/memn2n
Data: http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz
Paper: https://arxiv.org/abs/1503.08895
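1 Task
The goal is question answering (QA). The dataset used is bAbI.
The sentences and the question are fed to the model together, in the hope that the model learns this kind of reasoning pattern.
Take task 1 as an example: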
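Every two statements are followed by a question together with its answer. The number after the answer indicates which line the answer depends on. Fifteen lines make up one story: the 15 lines are related to one another, and the next 15 lines form a separate story. memory_size is therefore 10 (of the 15 lines, 10 are statements and 5 are questions). The maximum sentence length is 7, so a processed story should be a 10*7 matrix (memory_size by sentence length). Each 15-line story is turned into 5 training samples: the first consists of the first two statements plus the question and answer, the second of the first four statements plus the question and answer, and so on. In other words, each question is answered from all the statements that precede it.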
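The way one story expands into several samples can be sketched as follows. This is a minimal illustration, not the repository's code: the helper name `build_samples` and the six story lines are assumptions made up for the example, written in the bAbI line format (a line id, then the sentence, with question lines carrying a tab-separated answer and supporting line id).

```python
def build_samples(lines):
    """Turn the lines of bAbI-style stories into (context, question, answer) samples.

    Hypothetical helper for illustration only: each question is paired with
    every statement that precedes it in the same story.
    """
    samples = []
    context = []
    for line in lines:
        nid, text = line.strip().split(' ', 1)
        if int(nid) == 1:
            # A line id of 1 starts a new story, so the context is reset.
            context = []
        if '\t' in text:
            # Question line: question, answer, supporting line id.
            question, answer, _supporting = text.split('\t')
            samples.append((list(context), question.strip(), answer))
        else:
            # Statement line: it becomes part of the context for later questions.
            context.append(text)
    return samples


# Illustrative story fragment (assumed data, shortened to two questions).
story = [
    "1 John travelled to the hallway.",
    "2 Mary journeyed to the bathroom.",
    "3 Where is John?\thallway\t1",
    "4 Daniel went back to the bathroom.",
    "5 John moved to the bedroom.",
    "6 Where is Mary?\tbathroom\t2",
]

samples = build_samples(story)
for context, question, answer in samples:
    print(len(context), question, answer)
```

The first sample carries two statements and the second carries four, which matches the description above: a full 15-line story would produce five samples with 2, 4, 6, 8 and 10 statements of context.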
2 Data processing
The main function that loads a task's data is shown below.
import os

def load_task(data_dir, task_id, only_supporting=False):
    '''Load the nth task. There are 20 tasks in total.

    Returns a tuple containing the training and testing data for the task.
    '''
    assert task_id > 0 and task_id < 21

    files = os.listdir(data_dir)
    files = [os.path.join(data_dir, f) for f in files]
    s = 'qa{}_'.format(task_id)
    # Pick the training and test files that belong to the requested task.
    train_file = [f for f in files if s in f and 'train' in f][0]
    # For task 1 this is the file qa1_single-supporting-fact_test.txt.
    test_file = [f for f in files if s in f and 'test' in f][0]
    train_data = get_stories(train_file, only_supporting)
    test_data = get_stories(test_file, only_supporting)
    return train_data, test_data
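The file selection step can be tried in isolation without the dataset on disk. The sketch below is an assumption-laden variant, not the repository's code: `pick_task_files` is a hypothetical helper, the `data/en` directory is made up, and it matches the prefix against the base name with `startswith`, which is slightly stricter than the `s in f` test used in `load_task`. It shows why the trailing underscore in `'qa{}_'` matters: `qa1_` does not match the files of task 11.

```python
import os


def pick_task_files(files, task_id):
    """Return (train_file, test_file) for one task from a list of paths.

    Hypothetical helper: matches on the base name only, so directory names
    cannot produce accidental hits.
    """
    prefix = 'qa{}_'.format(task_id)
    candidates = [f for f in files if os.path.basename(f).startswith(prefix)]
    train_file = [f for f in candidates if 'train' in os.path.basename(f)][0]
    test_file = [f for f in candidates if 'test' in os.path.basename(f)][0]
    return train_file, test_file


# Assumed directory layout, for illustration only.
files = [
    'data/en/qa1_single-supporting-fact_train.txt',
    'data/en/qa1_single-supporting-fact_test.txt',
    'data/en/qa11_basic-coreference_train.txt',
    'data/en/qa11_basic-coreference_test.txt',
]

train_file, test_file = pick_task_files(files, 1)
print(train_file)
print(test_file)
```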
Next is get_stories: given a file name, it reads the file and returns the stories.
def get_stories(f, only_supporting=False):
    '''Given a file name, read the file, retrieve the stories, and then convert the sentences into a single story.

    If max_length is supplied, any stories longer than max_length tokens will be discarded.
    '''
    with open(f) as f:
        return parse_stories(f.readlines(), only_supporting=only_supporting)
The following shows how each story is parsed.
def parse_stories(lines, only_supporting=False):
    '''Parse stories provided in the bAbI tasks format

    If only_supporting is true, only the sentences that support the answer are kept.
    '''
    data = []
    story = []
    for line in lines:
        # Lowercase the line and split it into its id and the sentence, e.g.
        # "2 Mary journeyed to the bathroom." becomes 2 and "Mary journeyed to the bathroom."
        line = str.lower(line)
        nid, line = line.split(' ', 1)
        nid = int(nid)
        # nid is the line number within a story (1 to 15). A value of 1 marks
        # the start of a new story, so the story list is reset.
        if nid == 1:
            story = []
        # Example lines:
        #   2 Mary journeyed to the bathroom.
        #   3 Where is John? hallway 1
        if '\t' in line:  # a tab means this is a question line
            # supporting tells which line the answer is inferred from
            q, a, supporting = line.split('\t')
            q = tokenize(q)
            #a = tokenize(a)