TensorFlow中RNN样例代码详解

关于RNN的理论部分已经在上一篇文章中讲过了,本文主要讲解RNN在TensorFlow中的实现。与theano不同,TensorFlow在一个更加抽象的层次上实现了RNN单元,所以调用tensorflow的API来实现RNN是比较容易的。这里先介绍TensorFlow中与RNN相关的几个比较常用的函数,

(1)cell = tf.nn.rnn_cell.BasicLSTMCell(num_units, forget_bias, input_size, state_is_tuple, activation)
     num_units: int, The number of units in the LSTM cell(就是指cell中隐藏层神经元的个数);
     forget_bias: float, The bias added to forget gates (添加到“forget gates”的偏置,这里的“forget gates”指lstm网络中的component);
     input_size: Deprecated and unused(这个参数以后会被废弃掉,就不用考虑了);
     state_is_tuple: 为真表示,状态值是(c_state, m_state)构成的元组,比如每一个time step有K层,那么state结构为((c0, m0), (c1, m1), …, (ck, mk));
     activation: cell中的激励函数;
   注:这个函数用于生成RNN网络的最基本的组成单元,这个类对象中还有一个比较重要的method,call(self, inputs, state, scope=None),它确定
     了在forward propagation过程中,调用BasicLSTMCell对象时的输入输出参数。

(2) cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
     cells: list of RNNCells that will be composed in this order(根据cells列表中的LSTMCell生成MultiRNNCell的基本组成单元,这里的MultiRNNCell是
       指每一时刻的输出由多层LSTMCell级联而成。显然,列表中的每个LSTMCell可以含有不同的权重参数);
     state_is_tuple: 同上;

(3) state = tf.nn.rnn_cell.MultiRNNCell.zero_state(batch_size, dtype)
     batch_size: 训练块的大小;
     dtype: 指定待返回的state变量的数据类型;
   注:这个函数用于返回全0的state tensor。state tensor的尺寸与层数、hidden units num、batch size有关系,前面两个在定义cell对象时已经指定过
     了,故这里要指定batch_size参数。

在Github上有RNN的TensorFlow官方源代码,主要包括了两个文件,一个是reader.py,另外一个是ptb_word_lm.py。本篇就先来学习一下大牛们提供的源代码,因为代码比较长,这里主要对理解上可能有困难的地方进行解析,希望能对大家有所帮助。

reader.py文件中的子函数

在NLP领域中,自然语言模型是比较经典的应用,在训练RNN模型前,需要把输入数据文件进行预处理,即先设定词库大小vocabulary_size,再根据训练库中单词出现的频数,找到出现次数最多的前vocabulary_size个单词,并把他们映射到0,、、、,vocabulary_size-1,而其他出现频数较少的单词,均设置成“unknown”,索引设置为vocabulary_size。通常情况下,训练数据包含了很多段语句,每段语句的长度可以不一样(用列表和array对象存储矩阵数据时,矩阵中元素的长度可以不一致,所以语料库的存储不存在问题)。当模型训练过程结束时,所学到的模型参数,就是使得训练库中所有的sentence出现概率都非常大时的参数解。值得一提的是,TF仅支持定长输入的RNN(theano中的scan函数支持不定长输入的RNN,但在实际应用中,通常都是提前给inputs加个padding改成定长的训练语料库,因为这样做会使训练速度更快)。

def ptb_raw_data(data_path=None):
  train_path = os.path.join(data_path, "ptb.train.txt") #定义文件路径
  valid_path = os.path.join(data_path, "ptb.valid.txt")
  test_path = os.path.join(data_path, "ptb.test.txt")
    #_build_vocab函数对字典对象,先按value(频数)降序,频数相同的单词再按key(单词)升序。函数返回的是字典对象, 
    # 函数返回的是字典对象,key为单词,value为对应的唯一的编号
  word_to_id = _build_vocab(train_path)
    # _file_to_word_ids函数,用于把文件中的内容转换为索引列表。在转换过程中,若文件中的某个单词不在word_to_id查询字典中,
    # 则不进行转换。返回list对象,list中的每一个元素均为int型数据,代表单词编号
  train_data = _file_to_word_ids(train_path, word_to_id)
  valid_dat
  • 3
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 14
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值