LSTM中tf.nn.dynamic_rnn处理过程详解

在唐宇迪之tensorflow学习笔记项目实战(LSTM情感分析)一文中,链接地址如下https://blog.csdn.net/liushao123456789/article/details/78991581。对于tf.nn.dynamic_rnn处理过程的代码如下,但是每一步缺少细致的解释,本博客旨在帮助小伙伴们详细了解每一的步骤以及为什么要这样做。

lstmCell = tf.contrib.rnn.BasicLSTMCell(lstmUnits)
lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.75)
value, _ = tf.nn.dynamic_rnn(lstmCell, data, dtype=tf.float32)

 lstmUnits为神经元的个数,前两行代码比较好理解,第三行代码生成的value和_令我百思不得其解。接着又出现另外几行代码更让我云里雾里。

weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses]))
bias = tf.Variable(tf.constant(0.1, shape=[numClasses]))
value = tf.transpose(value, [1, 0, 2])
#取最终的结果值
last = tf.gather(value, int(value.get_shape()[0]) - 1)
prediction = (tf.matmul(last, weight) + bias)

看到这里不禁会发问,为什么要对value进行value = tf.transpose(value, [1, 0, 2])这部分操作,然后last = tf.gather(value, int(value.get_shape()[0]) - 1)这一步又有什么作用?带着这些疑问,我通过不停地百度,参考https://blog.csdn.net/qq_35203425/article/details/79572514这篇文章终于得出解答。

首先tf.nn.dynamic_rnn的输出包括outputsstates两部分。在唐宇迪例子中value相当于outputs,我们需要找outputs的最后一个step的输出。对value进行value = tf.transpose(value, [1, 0, 2])操作后得到的shape为[step,batch_size,lstmUnits].而后last = tf.gather(value, int(value.get_shape()[0]) - 1),其中value.get_shape()[0]) - 1找到value经过transpose后的最后一个分片,last = tf.gather(value, int(value.get_shape()[0]) - 1)表示最后一个[batch_size,lstmUnits],也就是lstm最后的输出,这时候weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses]))的shape为[lstmUnits,numClasses],last的shape为[batch_size,lstmUnits],两者相乘的维度为[batch_size,numClasses],再与偏置向量相加即可得到。真的输出应该是states.h。

笔者认为应该是:states是由(c,h)组成的tuple,大小均为[batch,lstmUnits]。所以如果想用dynamic_rnn得到输出后,只需要最后一次的状态输出,直接调用states.h即可,也可以按照上述进行操作

  • 9
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值