解读(Solving Math Word Problems with Multi-Encoders and Multi-DEcoders)的代码(数据处理部分)

导入必要的包

import os,json,time,re,copy,random
from collections import Counter
import numpy as np
import torch
import torch.nn as nn

加载数据的函数

def load_raw_data(filename):
    data=[]
    with open(filename,encoding='utf-8') as f:
        lines=f.readlines()#lines是一个列表,每一个元素是文件中的一行,文件中的每7行组成一条训练样本
        
    json_string=''
    #每7行是一个样本
    for line_id,line_str in enumerate(lines):
        json_string+=line_str
        line_id+=1
        if line_id%7==0:
            example=json.loads(json_string)#json.loads可以将字典形式的字符串转换成一个字典
            #example是一个字典,key值有'id','original_text','segmented_text','equation','ans'
            if '千米/小时' in example['equation']:
                example['equation']=example['equation'][:-5]#有些等式中含有(千米/小时)这个单位,把这个单位去掉
            data.append(example)
            json_string=''
    return data

在这里插入图片描述

构造输入数据和输出数据的形式

这个函数是数据处理部分的核心函数

def transfer_num(data):
    '''
    将数据集中的每一个样本对应的文本问题中的数字替换成NUM
    '''
    #正则表达式中: +表示出现一次或多次,*表示出现零次或多次
    #\d*\(\d+/\d+\)\d*  这个正则是为了匹配行如 (3/5)、2(3/5)、2(3/5)12 这类的数字(也就是带有括号的分数)
    #\d+\.\d+%? 这个正则是为了匹配行如 3.5 3.5% 这类的数字(也就是小数或者带有百分号的小数)
    #\d+%? 这个正则是为了匹配整数以及3%这类的带有百分号的整数
    
    pattern=re.compile("\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?")#pattern用来匹配问题文本中的所有数字
    
    pairs,generate_nums,generate_nums_dict=[],[],{}
    copy_nums=0#copy_nums用来记录数据集中的所有问题中,哪一个问题中出现的数字次数最多,copy_nums用来记录这个次数
    #copy_nums的数值影响着decoder端的词汇空间
    
    for example in data:
        #example行如:{'id': '5001','original_text': '某电视机厂原来每天生产116台电视机,现在每天生产的台数是原来的12倍,现在每天能生产多少台电视机?',
        #'segmented_text': '某 电视机厂 原来 每天 生产 116 台 电视机 , 现在 每天 生产 的 台数 是 原来 的 12 倍 , 现在 每天 能 生产 多少 台 电视机 ?',
        #'equation': 'x=116*12','ans': '1392'}
        idx=example['id']
        nums=[]#nums用来记录根据pattern匹配出的问题文本中的所有数字
        input_seq=[]#input_seq用来将问题文本中的所有数字替换成NUM
        seg=example['segmented_text'].strip().split(' ')
        #seg行如: ['某', '电视机厂', '原来', '每天', '生产', '116', '台', '电视机', ',', '现在', '每天', '生产', '的', '台数', '是', '原来', '的', '12', '倍', ',', '现在', '每天', '能', '生产', '多少', '台', '电视机', '?']
        equation=example['equation'][2:]
        #equations的形式行如: x=(25+14)/(1-(1/5)-(1/5));x=(11-1)*2;x=116*12等
        for token in seg:
            pos=re.search(pattern=pattern,string=token)#如果token是数字,那么pos返回的不是None
            if pos and pos.start()==0:
                nums.append(token[pos.start():pos.end()])
                input_seq.append('NUM')#将所有数字替换成NUM
                if pos.end()<len(token):
                    #说明此时的token不仅仅含有数字,eg: 116千克,那么input_seq中要添加千克这个单词
                    input_seq.append(token[pos.end():])
            elif token!='':
                #此时的token中没有数字
                input_seq.append(token)
        if copy_nums<len(nums):
            copy_nums=len(nums)#copy_nums用来记录所有问题中出现数字次数最多的那个问题出现的数字的次数
            
        nums_fraction=[]#nums_fraction用来记录这个问题中出现的行如(2/5)这种带括号的分数数字
        for num in nums:
            if re.search('\d*\(\d+/\d+\)\d*',num):
                nums_fraction.append(num)#num行如 5(2/5) (2/5) 5(2/5)5 这种,
        nums_fraction=sorted(nums_fraction,key=lambda x:len(x),reverse=True)#将nums_fraction中的带括号的分数数字按照长度排序
        #实验表明,排序或者不排序一点关系没有
        def seg_and_tag(equation):
            '''
            seg_and_tag函数的作用是将equation,也就是表达式中的字符分割开,例如:equation='(25+14)/(1-(1/5)-(1/5))'
            那么返回的表达式应该是['(', '25', '+', '14', ')', '/', '(', '1', '-', '(1/5)', '-', '(1/5)', ')']
            同时要将各个数字替换成Ni,i代表这个数字在问题文本中出现的顺序
            这也是为什么前面要用nums_fraction专门保存带括号的分数,这样才能使得整个括号和分数看成一个整体
            '''
            res=[]
            for num in nums_fraction:
                #如果nums_fraction是空列表,也就是说当前问题没有带括号的分数,那么这个for循环自然不会执行
                if num in equation:
                    #从equation中找到这个带括号的分数的位置
                    p_start=equation.find(num)
                    p_end=p_start+len(num)
                    if p_start>0:
                        #以上面的equation为例子,显然此时num等于(1/5),所以p_start>0,此时我们需要处理(25+14)/(1-
                        res+=seg_and_tag(equation[:p_start])
                    if nums.count(num)==1:
                        #也就是说这个数字仅在问题文本中出现过一次,那么此时就可以用Ni代替这个数字,
                        #i表示的是这个数字在文本中出现的顺序
                        res.append('N'+str(nums.index(num)))
                    else:
                        #说明这个数字在问题文本中出现了多次,那么此时直接记录这个数字,而不用Ni替代
                        res.append(num)
                    if p_end<len(equation):
                        res+=seg_and_tag(equation[p_end:])#递归右边的部分
                    return res
            #现在已经将这类括号带分数的数字处理完毕,接下来处理整数、小数、百分数
            number_position=re.search(pattern='\d+\.\d+%?|\d+%?',string=equation)
            if number_position:
                p_start=number_position.start()
                p_end=number_position.end()
                if p_start>0:
                    #类似的,递归左边
                    res+=seg_and_tag(equation[:p_start])
                number=equation[p_start:p_end]
                if nums.count(number)==1:
                    res.append('N'+str(nums.index(number)))
                else:
                    res.append(number)
                if p_end<len(equation):
                    res+=seg_and_tag(equation[p_end:])
                return res
            #上面的代码是用来处理数字的,如:带有括号的分数、小数、整数、百分数等
            #下面的for循环处理equation中的 括号和+-/*
            for rest_op in equation:
                #rest_op要么是括号(),要么是+-/*
                res.append(rest_op)
            return res
        
        output_seq=seg_and_tag(equation=equation)#output_seq就是decoder端要生成的表达式标签
        
        for token in output_seq:
            if token[0].isdigit() and token not in generate_nums and token not in nums:
                #说明此时这是一个数字,并且这个数字没有出现在问题中,这类数字包括1或者3.14这种常数
                generate_nums.append(token)
                generate_nums_dict[token]=1
            if token in generate_nums and token not in nums:
                generate_nums_dict[token]+=1
        
        num_pos=[]#num_pos用来记录每一个数字的位置将equation
        for i,j in enumerate(input_seq):
            #input_seq是将问题中的所有数字替换成NUM后的变量
            if j=='NUM':
                num_pos.append(i)
        assert len(nums)==len(num_pos)
        #nums记录的是每一个数字,num_pos记录的是每一个数字的位置
        pairs.append((idx,input_seq,output_seq,nums,num_pos))
        
    #结束for循环后,我们就已经处理了所有的问题,接下来统计数据集中频繁出现的常数
    temp_g=[]#用来记录数据集中频繁出现的常数,比如3.14
    for g in generate_nums:
        if generate_nums_dict[g]>=5:
            temp_g.append(g)
    return pairs,temp_g,copy_nums

我们通过几幅图片来看
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

提取词性标注和句法分析特征

这里用到了哈工大的pyltp工具包,我们通过两幅图来看
在这里插入图片描述
需要注意的是Root默认占用0,所以其它单词的索引是需要id-1的,这也是为什么源码中有arc.head-1这行代码,不过由于版本问题,此时的arc是一个元祖tuple,不过含义是一样的。

在这里插入图片描述
也就是说postagger用来标注每一个单词的词性(名词、动词等),parser用来提取整个句子中各个单词的依存句法关系,关于上面的具体的细节以及ATT,SBV,WP这都是什么玩意,不在详细介绍。

中缀转前缀或者后缀表达式

关于原理,请参考中缀表达式转后缀表达式

首先设置两个栈,操作数栈和运算符栈

中缀转前缀

  • 从右至左扫描表达式
  • 如果是右括号,例如:)]},那么直接push进运算符栈
  • 如果是左括号,例如:([{,那么由于我们是从右边扫描的,此时栈中一定有对应的右括号.做法就是一直弹出栈中的运算符(弹出的运算符放到操作数栈中),直到遇到对应的右括号,然后去掉这一对括号
  • 如果是运算符,那么此时要比较优先级,如果栈顶运算符的优先级大,那么就弹出栈顶的运算符(弹出的运算符放到操作数栈中),这个操作是一直执行到栈顶运算符的优先级小于当前运算符的优先级,然后将当前运算符push到运算符栈中
  • 如果是操作数,直接push到操作数栈中
  • 重复上面的步骤,直到扫描完整个中缀表达式
  • 将运算符栈中的所有元素push到操作数栈中,返回操作数栈
def from_infix_to_prefix(expression):
    operator_stack=[]#运算符栈
    operand_stack=[]#操作数栈
    operator_priority={'+':0,'-':0,'*':1,'/':1,'^':2}
    expression=deepcopy(expression)#deepcopy是深拷贝
    expression.reverse()#转前缀的过程是从右至左扫描
    for e in expression:
        if e in [')',']']:
            #当遇到右括号时,直接进栈
            operator_stack.append(e)
        elif e =='(':
            #弹出栈中的运算符,直到遇到)为止
            temp=operator_stack.pop()
            while temp!=')':
                operand_stack.append(temp)
                temp=operator_stack.pop()
        elif e=='[':
            #弹出栈中的运算符,直到遇到]为止
            temp=operator_stack.pop()
            while temp!=']':
                operand_stack.append(temp)
                temp=operator_stack.pop()
        elif e in operator_priority:
            #此时是运算符,需要比较优先级,当栈顶运算符的优先级大于e的优先级时,就一直弹栈
            #不过需要注意的是,如果栈顶是右括号,那么就不能再弹了,因为右括号要等到左括号来了才能弹栈
            while len(operator_stack)>0 and operator_stack[-1] not in [')',']'] and operator_priority[e]<operator_priority[operator_stack[-1]]:
                operand_stack.append(operator_stack.pop())
            operator_stack.append(e)
        else:
            #说明此时的e是操作数
            operand_stack.append(e)
    #将运算符栈中的剩余运算符全部弹出到操作数栈中
    while len(operator_stack)>0:
        operand_stack.append(operator_stack.pop())
    operand_stack.reverse()
    return operand_stack

在这里插入图片描述

中缀转后缀

思路是一样的,只不过有几个不同点:

  • 从左至右扫描中缀表达式
  • 由于是从左至右,所以遇到左括号直接压栈,遇到右括号弹栈
def from_infix_to_postfix(expression):
    operator_stack=[]
    operand_stack=[]
    expression=deepcopy(expression)
    operator_priority={'+':0,'-':0,'*':1,'/':1,'^':2}
    for e in expression:
        if e in ['(','[']:
            operator_stack.append(e)
        elif e ==')':
            temp=operator_stack.pop()
            while temp!='(':
                operand_stack.append(temp)
                temp=operator_stack.pop()
        elif e ==']':
            temp=operator_stack.pop()
            while temp!='[':
                operand_stack.append(temp)
                temp=operator_stack.pop()
        elif e in operator_priority:
            while len(operator_stack)>0 and operator_stack[-1] not in ['(','['] and operator_priority[e]<operator_priority[operator_stack[-1]]:
                operand_stack.append(operator_stack.pop())
            operator_stack.append(e)
        else:
            operand_stack.append(e)
    while len(operator_stack)>0:
        operand_stack.append(operator_stack.pop())
    return operand_stack

在这里插入图片描述

生成5折交叉验证的训练测试数据集

def generate_train_test(math23k_file):
    data=load_raw_data(math23k_file)#data的每一个元素是一个dict,字段有:id,original_text,segmented_text,equation,ans
    pairs,generate_nums,copy_nums=transfer_num(data)
    #pairs是将data的每一个数据里面的segmented_text中的数字转换成NUM,将equation中的数字转换成Ni,其中i
    #代表这个数字在问题中出现的顺序,pairs还有两个元素,分别记录问题对应的所有数字和数字的位置
    
    pre_temp_pairs=[]
    for p in pairs:
        #p[0]是id,p[1]是行如['新世纪', '百货', '开展', '“', '庆', 'NUM', '一', '”', '促销', '活动', ', '再', '降', 'NUM', '?'],
        #这样的问题
        postags=postagger.post(p[1])#也就是标注问题中的每一个单词的词性
        arcs=parser.parse(p[1],postags)#提取整个句子的句法
        parse_tree=[arc[0]-1 for arc in arcs]#其中arc是一个元祖(id,relation),id代表的就是当前这个单词与哪一个单词有关联,
        #id表示的就是那个单词在整个句子中的索引,但是由于ROOT这个单词默认占据0,所以单词的实际位置需要-1
        #relation表示的就是句法关系
        
        pre_temp_pairs.append((p[0],p[1],postags,parse_tree,
                               from_infix_to_prefix(p[2]),from_infix_to_postfix(p[2]),p[3],p[4]))
        #其中p[3]和p[4]分别是nums和nums_pos,也就是这个问题中所有的数字和数字的位置
        #p[2]就是中缀表达式,现在已经转换成前缀和后缀了
    pairs=pre_temp_pairs
    #接下来构造5折交叉验证的数据集
    fold_size=int(len(pairs)*0.2)#fold_size也就是每一折的测试集合大小,在math23k上约等于4632
    fold_pairs=[]
    for split_fold in range(4):
        fold_start=fold_size*split_fold
        fold_end=fold_size*(split_fold+1)
        fold_pairs.append(pairs[fold_start:fold_end])
    #split_fold==0,1,2,3
    #fold_pairs==[pairs[0:4632],pairs[4632:9264],pairs[9264:13896],pairs[13896:18528]]
    fold_pairs.append(pairs[fold_size*4:])#fold_pairs==[pairs[0:4632],pairs[4632:9264],pairs[9264:13896],pairs[13896:18528],pairs[18528:23162]]
    
    for fold in range(5):
        pairs_tested=[]
        pairs_trained=[]
        for fold_t in range(5):
            if fold_t==fold:
                #当fold==0时,就用fold_pairs[0]作为测试集,其它四个作为训练集
                pairs_tested+=fold_pairs[fold_t]
            else:
                pairs_trained+=fold_pairs[fold_t]
        with open("data/train"+str(fold)+".json",'w') as f:
            json.dump(pairs_trained,f,ensure_ascii=False,indent=4)
        with open("data/test"+str(fold)+".json","w") as f:
            json.dump(pairs_tested,f,ensure_ascii=False,indent=4)

在这里插入图片描述

train_example=pairs_trained[10]
print("example id : ",train_example[0])
print("example input seq : ",train_example[1])
print("example question pos(pos指的是词性) : ",train_example[2])
print("example syntatic parser(句法分析) : ",train_example[3])
print("example prefix expression : ",train_example[4])
print("example postfix expression : ",train_example[5])
print("example question nums : ",train_example[6])
print("example question nums_pos : ",train_example[7])

在这里插入图片描述

我们已经清楚了pairs_trained中每一个数据的结构

构造encoder和decoder的词典类

PAD_token=0#默认pad位置用0填充
class Lang:
    def __init__(self):
        self.word2index={}#词到id的转换字典
        self.word2count={}#词到词频的转换字典
        self.index2word=[]
        self.n_words=0
        self.num_start=0
    
    def add_sen_to_vocab(self,sentence):
        #传进来的sentence有多种形式   第一种是问题文本,行如:['要', '修', '一段', '长', 'NUM', '千米', '的', '路', ',', '第一天', '修', '了', 'NUM', '千米', ',', '第', '二', '天', '修', '了', '余下', '的', 'NUM', ',', '还', '剩下', '多少', '千米', '没有', '修', '完', '?']
        #第二种是句子的标注词性,行如['v', 'v', 'm', 'a', 'ws', 'q', 'u', 'n', 'wp', 'nt', 'v', 'u', 'ws', 'q', 'wp', 'm', 'm', 'q', 'v', 'u', 'v', 'u', 'ws', 'wp', 'd', 'v', 'r', 'q', 'd', 'v', 'v', 'wp']
        #这是因为论文有两个encoder,之前的论文只有一个encoder,只需要问题文本作为输入
        #第三种是前缀表达式,行如['-', '-', 'N0', '*', '-', 'N0', 'N1', 'N2', 'N1']
        #第四种是后缀表达式,行如['N0', 'N0', 'N1', '-', 'N2', '*', '-', 'N1', '-']
        for word in sentence:
            if re.search(pattern='N\d+|NUM|\d+',string=word):
                continue#数字和特殊字符NUM不作为encoder端的词汇
            if word not in self.index2word:
                self.word2index[word]=self.n_words
                self.word2count[word]=1
                self.index2word.append(word)
                self.n_words+=1
            else:
                self.word2count[word]+=1
    def trim(self,min_count):
        '''
        根据min_count去除词典中的单词,缩小词典的空间
        '''
        keep_words=[]
        for word,freq in self.word2count.items():
            if freq>=min_count:
                #词频高的词保留
                keep_words.append(word)
        self.word2index={}
        self.word2count={}
        self.index2word=[]
        self.n_words=0
        
        for word in keep_words:
            self.word2index[word]=self.n_words
            self.index2word.append(word)
            self.n_words+=1
    
    def build_input_lang(self,trim_min_count):
        if trim_min_count>0:
            self.trim(min_count=trim_min_count)
            self.index2word=['PAD','NUM','UNK']+self.index2word#因为删除了一些单词后,在训练集中自然会出现一些没有见过的单词
        else:
            self.index2word=['PAD','NUM']+self.index2word
        #重置word2index,因为要考虑PAD和NUM以及UNK等特殊字符
        self.word2index={word:index for index,word in enumerate(self.index2word)}
    
    def build_input_lang_for_pos(self):
        #对于词性标注的输入,没有NUM需要考虑,而且不需要删除不常见单词
        self.index2word=['PAD','UNK']+self.index2word#需要注意的是,调用这个函数的对象一定是词性标注输入的对象
        self.n_words=len(self.index2word)
        self.word2index={word:index for index,word in enumerate(self.index2word)}
    
    def build_output_lang(self,generate_nums,copy_nums):
        '''
        generate_nums代表的是常数,如: 1,3.14
        copy_nums代表的是出现数字次数最多的那个问题出现的数字次数,copy_nums决定了decoder端最多可以预测多少个不同数字
        '''
        self.index2word+=['PAD','EOS']+generate_nums+['N'+str(i) for i in range(copy_nums)]+['SOS','UNK']
        self.n_words=len(self.index2word)
        self.word2index={word:index for index,word in enumerate(self.index2word)}
    def build_output_lang_for_tree(self,generate_nums,copy_nums):
        '''
        树形结构的decoder和sequence结构的decoder是不同的,因为tree结构不是序列式的生成表达式,所以不考虑PAD和EOS,SOS等
        '''
        self.num_start=len(self.index2word)
        self.index2word+=generate_nums+['N'+str(i) for i in range(copy_nums)]+['UNK']
        self.n_words=len(self.index2word)
        self.word2index={word:index for index,word in enumerate(self.index2word)}
        

验证一下

input1_lang = Lang()
input2_lang = Lang()
output1_lang = Lang()
output2_lang = Lang()

for pair in pairs_trained:
    if pair[-1]:
        input1_lang.add_sen_to_vocab(pair[1])#pair[1]是问题文本
        input2_lang.add_sen_to_vocab(pair[2])#pair[2]是问题句子的词性
        output1_lang.add_sen_to_vocab(pair[4])#pair[4]是前缀表达式
        output2_lang.add_sen_to_vocab(pair[5])#pair[5]是后缀表达式
        
trim_min_count=5
input1_lang.build_input_lang(trim_min_count)
input2_lang.build_input_lang_for_pos()
output1_lang.build_output_lang_for_tree(generate_nums, copy_nums)
output2_lang.build_output_lang(generate_nums, copy_nums)

在这里插入图片描述
在这里插入图片描述

将句子转为id序列

def indexes_from_sentence(lang,sentence,tree=False):
    '''
    根据lang中的word2index将sentence中的每一个token转为对应的id
    这里面的sentence不一定是句子,也可能是词性标注序列,或者输出的前缀后缀表达式
    '''
    res=[]
    unk_token=lang.word2index['UNK']
    for token in sentence:
        if len(token)==0:
            continue
        res.append(lang.word2index.get(token,unk_token))
    if 'EOS' in lang.index2word and not tree:
        #输出端有两个decoder,其中一个是sequence式结构,另一个是tree结构
        #sequence结构中需要有'EOS'
        res.append(lang.word2index['EOS'])
    return res

在这里插入图片描述
在这里插入图片描述

def texts_from_sentence(lang, sentence, tree=False):
    '''
    函数的目的是将sentence中出现的词汇如果不在lang.word2index中,那么就换成UNK
    '''
    res = []
    for word in sentence:
        if len(word) == 0:
            continue
        if word in lang.word2index:
            res.append(word)
        else:
            res.append("UNK")
    if "EOS" in lang.index2word and not tree:
        res.append(lang.word2index["EOS"])
    return res

def num_list_processed(num_list):
    '''
    num_list代表的是一个问题中所有的数字
    函数的目的是将num_list中的数字进一步换算成对应的值,同时将百分号等数字替换成对应的小数
    将分数也同样计算成对应的小数
    '''
    st = []
    for p in num_list:
        pos1 = re.search("\d+\(", p)
        pos2 = re.search("\)\d+", p)
        if pos1:
            st.append(eval(p[pos1.start(): pos1.end() - 1] + "+" + p[pos1.end() - 1:]))
        elif pos2:
            st.append(eval(p[:pos2.start() + 1] + "+" + p[pos2.start() + 1: pos2.end()]))
        elif p[-1] == "%":
            st.append(float(p[:-1]) / 100)
        else:
            st.append(eval(p))
    return st

def num_order_processed(num_list):
    '''
    由于论文中提出要比较一个问题中所有数字的大小,所以这个函数的作用就是用整数来表达一个数字在当前这个问题中的所有
    数字的大小,数值的大小代表的是这个数字大于多少个数字
    '''
    num_order = []
    num_array = np.asarray(num_list)
    for num in num_array:
        num_order.append(sum(num>num_array)+1)
    
    return num_order

在这里插入图片描述

准备传入模型的数据

def prepare_data(pairs_trained,pairs_tested,trim_min_count,generate_nums,copy_nums):
    '''
    pairs[0]-->id,问题样本id
    pairs[1]-->input seq,问题文本
    pairs[2]-->pos,问题单词的词性标注
    pairs[3]-->parser,句法分析的结果
    pairs[4]-->prefix expression
    pairs[5]-->postfix expression
    pairs[6]-->nums
    pairs[7]-->nums_pos
    '''
    input1_lang = Lang()
    input2_lang = Lang()
    output1_lang = Lang()
    output2_lang = Lang()
    train_pairs = []
    test_pairs = []

    print("Indexing words...")
    for pair in pairs_trained:
        if pair[-1]:
            input1_lang.add_sen_to_vocab(pair[1])
            input2_lang.add_sen_to_vocab(pair[2])
            output1_lang.add_sen_to_vocab(pair[4])
            output2_lang.add_sen_to_vocab(pair[5])
    
    input1_lang.build_input_lang(trim_min_count)
    input2_lang.build_input_lang_for_pos()
    output1_lang.build_output_lang_for_tree(generate_nums, copy_nums)
    output2_lang.build_output_lang(generate_nums, copy_nums)

    for pair in pairs_trained:
        num_stack = []
        for word in pair[4]:
            #pair[4]是前缀表达式,行如['/', '*', 'N1', 'N2', '5']
            temp_num = []
            flag_not = True
            #output1_lang是树形结构decoder的词空间
            if word not in output1_lang.index2word:
                #这种情况是因为前缀表达式中出现了数字,而我们知道,数字是不作为词空间中的元素的
                #表达式中按理说所有的数字都已经被转为对应的Ni了,出现数字的原因是这个数字在问题中出现了多次
                flag_not = False
                for i, j in enumerate(pair[6]):
                    #pair[6]是nums,也就是每一个数字,行如 ['5', '16.5', '2.1', '5']
                    if j == word:
                        temp_num.append(i)#temp==[0,3],temp记录的是表达式中出现的重复的数字在nums中的位置

            if not flag_not and len(temp_num) != 0:
                num_stack.append(temp_num)
            if not flag_not and len(temp_num) == 0:
                num_stack.append([_ for _ in range(len(pair[6]))])

        #num_stack.reverse()#实验表明,这行代码没有用
        input1_cell = indexes_from_sentence(input1_lang, pair[1])#pair[1] is input_seq
        texts_cell = texts_from_sentence(input1_lang, pair[1])
        input2_cell = indexes_from_sentence(input2_lang, pair[2])#pair[2] is input seq pos
        output1_cell = indexes_from_sentence(output1_lang, pair[4], True)#pair[4] is prefix_expression, used for tree-decoder
        output2_cell = indexes_from_sentence(output2_lang, pair[5], False)#pair[5] is postfix expression, 
        num_list = num_list_processed(pair[6])#pair[6] is nums
        num_order = num_order_processed(num_list)
        train_pairs.append((pair[0], texts_cell, input1_cell, input2_cell, pair[3], len(input1_cell), 
                            output1_cell, len(output1_cell), output2_cell, len(output2_cell), 
                            pair[6], pair[7], num_stack, num_order))
    print('Indexed %d words in input language, %d words in output1, %d words in output2' % 
          (input1_lang.n_words, output1_lang.n_words, output2_lang.n_words))
    print('Number of training data %d' % (len(train_pairs)))
    for pair in pairs_tested:
        num_stack = []
        for word in pair[4]:
            temp_num = []
            flag_not = True
            if word not in output1_lang.index2word:
                flag_not = False
                for i, j in enumerate(pair[6]):
                    if j == word:
                        temp_num.append(i)

            if not flag_not and len(temp_num) != 0:
                num_stack.append(temp_num)
            if not flag_not and len(temp_num) == 0:
                num_stack.append([_ for _ in range(len(pair[6]))])

        num_stack.reverse()
        input1_cell = indexes_from_sentence(input1_lang, pair[1])
        texts_cell = texts_from_sentence(input1_lang, pair[1])
        input2_cell = indexes_from_sentence(input2_lang, pair[2])
        output1_cell = indexes_from_sentence(output1_lang, pair[4], True)
        output2_cell = indexes_from_sentence(output2_lang, pair[5], False)
        num_list = num_list_processed(pair[6])
        num_order = num_order_processed(num_list)
        test_pairs.append((pair[0], texts_cell, input1_cell, input2_cell, pair[3], len(input1_cell), 
                           output1_cell, len(output1_cell), output2_cell, len(output2_cell), 
                           pair[6], pair[7], num_stack, num_order))
    print('Number of testind data %d' % (len(test_pairs)))
    return input1_lang, input2_lang, output1_lang, output2_lang, train_pairs, test_pairs

input1_lang, input2_lang, output1_lang, output2_lang, train_pairs, test_pairs = prepare_data(pairs_trained, pairs_tested, 5, generate_nums, copy_nums)
train_example=train_pairs[500]
print("example id : ",train_example[0])
print("example input seq (词频少的单词已经被替换成UNK): ",train_example[1])
print("将所有的单词替换成对应的id : ",train_example[2])
print("将标注的词性替换成对应的id : ",train_example[3])
print("句法分析的结构: ",train_example[4])
print("句子长度 : ",train_example[5])
print("将前缀表达式中的运算符替换成对应的id : ",train_example[6])
print("前缀表达式的长度 : ",train_example[7])
print("将后缀表达式中的运算符替换成对应的id : ",train_example[8])
print("后缀表达式的长度(后缀表达式是作为sequence decoder的标签,所以包含EOS,长度要更长一些) : ",train_example[9])
print('这个问题对应的所有的数字 : ',train_example[10])
print('这个问题中数字的位置 : ',train_example[11])
print('这个问题是否包含有重复数字,如果有,重复数字出现的位置 : ',train_example[12])
print('这个问题中所有数字的大小关系 : ',train_example[13])

在这里插入图片描述

构造真正的输入数据

def prepare_train_batch(pairs_to_batch,batch_size):
    '''
    这个函数用来构造输入数据
    对于pairs_to_batch中的每一个元素example,都有14个字段,分别是
    example id;example input seq (词频少的单词已经被替换成UNK);
    example input_seq_id(所有的单词替换成对应的id);example pos_id(将标注的词性替换成对应的id);
    example parse(句法分析的结构);example_length(句子长度);
    example prefix_expression_id(将前缀表达式中的运算符替换成对应的id);prefix_expression length(前缀表达式的长度);
    example postfix_expression_id(将后缀表达式中的运算符替换成对应的id);postfix_expression length(后缀表达式的长度);
    example question nums(这个问题对应的所有的数字);example question nums_pos(这个问题中数字的位置);
    example question num_stack(这个问题是否包含有重复数字,如果有,重复数字出现的位置);example question num_order(这个问题中所有数字的大小关系)
    每一个example有14个字段
    '''
    
    pairs=deepcopy(pairs_to_batch)
    random.shuffle(pairs)#随机打乱训练数据,因为我们要保证各个数据样本之间是相互独立的,满足iid条件
    
    id_batches=[]#存储各个样本的id
    input1_batches=[]#存储各个样本中问题对应的id(将问题文本中的单词转成id)
    input2_batches=[]#存储各个样本中问题的每一个单词对应的词性标注对应的id
    #input1和input2都是sequence encoder的输入
    input_lengths=[]#存储各个样本中问题的长度
    output1_lengths=[]#存储各个样本中问题对应的前缀表达式的长度
    output2_lengths=[]#存储各个样本中问题对应的后缀表达式的长度
    nums_batches=[]#存储各个样本中问题中出现的数字个数,也就是len(nums)
    num_pos_batches=[]#对应的,存储各个样本中问题中出现的数字在问题中的索引
    num_order_batches=[]#存储每一个问题中各个数字之间的大小关系
    num_stack_batches=[]#如果问题中出现了重复数字,记录重复数字在nums中的位置,否则是[]
    num_size_batches=[]
    output1_batches = []
    output2_batches = []
    parse_graph_batches = []#存储句法解析
    
    batches=[]#按照批次来存储数据,每一批数据为一个单词
    num_of_batch=0
    print()
    print('一共有{}个训练数据样本,按照{}为批次大小,所以一共有{}个训练批次'.format(len(pairs),batch_size,len(pairs)//batch_size+1))
    while num_of_batch+batch_size<len(pairs):
        batches.append(pairs[num_of_batch:num_of_batch+batch_size])
        num_of_batch+=batch_size
    batches.append(pairs[num_of_batch:])
    
    for batch in batches:
        #在每一个批次中,按照这个批次的每一个句子的长度排序,句子长的放在前面,这样有助于后面的RNN编码
        batch=sorted(batch,key=lambda example:example[5],reverse=True)#example[5]是句子长度
        input_length=[]
        output1_length=[]
        output2_length=[]
        for id_,input_seq,seq_id,pos_id,parse,seq_len,prefix_id,prefix_len,postfix_id,postfix_len,nums,nums_pos,num_stack,num_order in batch:
            input_length.append(seq_len)
            output1_length.append(prefix_len)
            output2_length.append(postfix_len)
        input_lengths.append(input_length)
        output1_lengths.append(output1_length)
        output2_lengths.append(output2_length)
        input_len_max = input_length[0]#当前这个批次中所有问题长度的最大值
        output1_len_max = max(output1_length)
        output2_len_max = max(output2_length)
        
        id_batch = []
        input1_batch = []
        input2_batch = []
        output1_batch = []
        output2_batch = []
        num_batch = []
        num_stack_batch = []
        num_pos_batch = []
        num_order_batch = []
        num_size_batch = []
        parse_tree_batch = []
        
        for idx,input_seq,seq_id,pos_id,parse,seq_len,prefix_id,prefix_len,postfix_id,postfix_len,num,num_pos,num_stack,num_order in batch:
            id_batch.append(idx)
            seq_id+=[PAD_token for _ in range(input_len_max-seq_len)]#pad
            pos_id+=[PAD_token for _ in range(input_len_max-seq_len)]#pad
            input1_batch.append(seq_id)
            input2_batch.append(pos_id)
            prefix_id+=[PAD_token for _ in range(output1_len_max-prefix_len)]
            postfix_id+=[PAD_token for _ in range(output2_len_max-postfix_len)]
            #表达式同样需要pad
            output1_batch.append(prefix_id)
            output2_batch.append(postfix_id)
            num_batch.append(len(num))#这个问题出现了多少个数字
            num_stack_batch.append(num_stack)#是否有重复数字
            num_pos_batch.append(num_pos)#数字的位置
            num_order_batch.append(num_order)#数字之间的大小关系
            num_size_batch.append(len(num_pos))
            assert len(num)==len(num_pos)
            parse_tree_batch.append(parse)
            
        id_batches.append(id_batch)
        input1_batches.append(input1_batch)
        input2_batches.append(input2_batch)
        output1_batches.append(output1_batch)
        output2_batches.append(output2_batch)
        nums_batches.append(num_batch)
        num_stack_batches.append(num_stack_batch)
        num_pos_batches.append(num_pos_batch)
        num_order_batches.append(num_order_batch)
        num_size_batches.append(num_size_batch)
        
        parse_g=get_parse_graph_batch(input_length, parse_tree_batch)
        assert type(parse_g)==np.ndarray
        assert parse_g.shape==(len(batch),3,input_len_max,input_len_max)
        parse_graph_batches.append(parse_g)
        
    return id_batches, input1_batches, input2_batches, input_lengths, output1_batches, output1_lengths, output2_batches, output2_lengths, \
       nums_batches, num_stack_batches, num_pos_batches, num_order_batches, num_size_batches, parse_graph_batches

在这里插入图片描述

已标记关键词 清除标记
相关推荐
©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页