关于RNN的理论部分已经在上一篇文章中讲过了,本文主要讲解RNN在TensorFlow中的实现。与theano不同,TensorFlow在一个更加抽象的层次上实现了RNN单元,所以调用tensorflow的API来实现RNN是比较容易的。这里先介绍TensorFlow中与RNN相关的几个比较常用的函数,
(1)cell = tf.nn.rnn_cell.BasicLSTMCell(num_units, forget_bias, input_size, state_is_tuple, activation)
num_units: int, The number of units in the LSTM cell(就是指cell中隐藏层神经元的个数);
forget_bias: float, The bias added to forget gates (添加到“forget gates”的偏置,这里的“forget gates”指lstm网络中的component);
input_size: Deprecated and unused(这个参数以后会被废弃掉,就不用考虑了);
state_is_tuple: 为真表示,状态值是(c_state, m_state)构成的元组,比如每一个time step有K层,那么state结构为((c0, m0), (c1, m1), …, (ck, mk));
activation: cell中的激励函数;
注:这个函数用于生成RNN网络的最基本的组成单元,这个类对象中还有一个比较重要的method,call(self, inputs, state, scope=None),它确定
了在forward propagation过程中,调用BasicLSTMCell对象时的输入输出参数。
(2) cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
cells: list of RNNCells that will be composed in this order(根据cells列表中的LSTMCell生成MultiRNNCell的基本组成单元,这里的MultiRNNCell是
指每一时刻的输出由多层LSTMCell级联而成。显然,列表中的每个LSTMCell可以含有不同的权重参数);
state_is_tuple: 同上;
(3) state = tf.nn.rnn_cell.MultiRNNCell.zero_state(batch_size, dtype)
batch_size: 训练块的大小;
dtype: 指定待返回的state变量的数据类型;
注:这个函数用于返回全0的state tensor。state tensor的尺寸与层数、hidden units num、batch size有关系,前面两个在定义cell对象时已经指定过
了,故这里要指定batch_size参数。
在Github上有RNN的TensorFlow官方源代码,主要包括了两个文件,一个是reader.py,另外一个是ptb_word_lm.py。本篇就先来学习一下大牛们提供的源代码,因为代码比较长,这里主要对理解上可能有困难的地方进行解析,希望能对大家有所帮助。
reader.py文件中的子函数
在NLP领域中,自然语言模型是比较经典的应用,在训练RNN模型前,需要把输入数据文件进行预处理,即先设定词库大小vocabulary_size,再根据训练库中单词出现的频数,找到出现次数最多的前vocabulary_size个单词,并把他们映射到0,、、、,vocabulary_size-1,而其他出现频数较少的单词,均设置成“unknown”,索引设置为vocabulary_size。通常情况下,训练数据包含了很多段语句,每段语句的长度可以不一样(用列表和array对象存储矩阵数据时,矩阵中元素的长度可以不一致,所以语料库的存储不存在问题)。当模型训练过程结束时,所学到的模型参数,就是使得训练库中所有的sentence出现概率都非常大时的参数解。值得一提的是,TF仅支持定长输入的RNN(theano中的scan函数支持不定长输入的RNN,但在实际应用中,通常都是提前给inputs加个padding改成定长的训练语料库,因为这样做会使训练速度更快)。
def ptb_raw_data(data_path=None):
train_path = os.path.join(data_path, "ptb.train.txt") #定义文件路径
valid_path = os.path.join(data_path, "ptb.valid.txt")
test_path = os.path.join(data_path, "ptb.test.txt")
#_build_vocab函数对字典对象,先按value(频数)降序,频数相同的单词再按key(单词)升序。函数返回的是字典对象,
# 函数返回的是字典对象,key为单词,value为对应的唯一的编号
word_to_id = _build_vocab(train_path)
# _file_to_word_ids函数,用于把文件中的内容转换为索引列表。在转换过程中,若文件中的某个单词不在word_to_id查询字典中,
# 则不进行转换。返回list对象,list中的每一个元素均为int型数据,代表单词编号
train_data = _file_to_word_ids(train_path, word_to_id)
valid_dat