tensorflow--rnn的注意点

参考:https://blog.csdn.net/mydear_11000/article/details/52414342

placeholder

要训练的模型是language model,也就是一个word,预测最有可能的下一个word,因此input和output是同型的,并且,placeholder只存储一个batch的data,input接收的是word在vocabulary中对应的index,后续会将index转成dense embedding,每次接收一个seq长度的words,input_shape=[batch_size,seq_length]

定义cell


这其中的每个小长方形就表示一个cell,每个cell又是一个略复杂的结构。一个cell包含多个hidden units(表示隐层神经元个数)。因此,tensorflow中定义一个cell结构时需要提供一个参数就是hidden_units_size

tips:每个time_step都复用一个cell

state_is_tuple:true时,接收和返回的是c_state,m_state的2-tuples;false时,则将c_state,m_state拼接在一起

DropoutWrapper

dropout是一种非常efficient的regularization方法,对于rnn的部分不进行dropout,也就是说从t-1时候的状态传递到t时刻进行计算时,这个中间不进行memory的dropout,仅在同一t时刻中,多层cell之间传递信息的时候进行dropout。


从t-2时刻的输入传入第一层cell,这个过程有dropout,从该时刻的第一层cell传到t-1,t,t+1的第一层cell这个中间都不进行dropout。再从t+2时刻的第一层cell向同一时刻内后续的cell传递是,这之间又有dropout了。

因此,我们在定义完cell之后,在cell外部包裹上dropout,这个类叫做DropoutWrapper,这样我们的cell就有dropout功能。

DropoutWrapper有input_keep_prob和output_kepp_prob,也就是说裹上这个DropoutWrapper之后,如果我希望是input传入这个cell时dropout掉一部分信息的话,就设置input_keep_prob,那么传入到cell的就是部分input;如果我希望这个cell的output部分作为下一层cell的input的话,就设置output_keep_prob。

Stack MultiCell

我们现在定义了一个lstm cell,这个cell仅是整个图中的一个小长方形,我们希望整个网络能更deep的话,应该stack多个这样的lstm cell,tensorflow给我们提供了MultiRNNCell(只有这一个类)

tensorflow并不是简单的堆叠了多个single cell,而是将这次cell stack之后当成了一个完整的独立cell,每个小cell的中间状态还是被保存下来了,按n_tuple存储,但输出output只用最后那个cell的输出。

这样就定义好了每个时刻t的整体cell,接下来只要每个时刻传入不同的输入,再在时间上展开,就能得到多个时间上unroll graph。

initial_state

初始化隐状态为0

接下来要给我们的multi lstm cell进行状态初始化,state_size是我们在定义MultiRNNCellde的时候就设置好的,只是我们的输入input_shape=[batch_size,num_steps],我们刚刚定义好的cell会依次接收num_steps个输入然后产生最后的state(n-tupe,n表示堆叠的层数),但是一个batch内有batch_size这样的seq,因此需要[batch_size,state_size]来存储整个batch每个seq的状态。

RNN循环起来

tf.nn.dynamic_rnn

outputs,states=tf.nn.dynamic_rnn(multicell,x,initial_state=_initial_state)

state是final state,如果有n layer,则是final state也有n个元素,对应每一层的state

tf.nn.rnn

。。。。。。。







评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值