简单LSTM代码讲解

仅供本人参考,错了概不负责

part1

图源:https://www.zhihu.com/question/41949741/answer/309529532

我们在使用tf.nn.rnn_cell.BasicLSTMCell时,有一个要自己设置的参数 num_units,先讲讲这玩意是啥?

这四个小黄块,有一定了解的同学都知道[ht-1, Xt]输入后,经过四个黄块和St-1,又得到了htSt,所以必然[ht-1, Xt]经过黄块后,维度和原ht-1一样。这个num_units就是ht-1维度,那几个小黄块就是线性映射后再结激活函数。

查看BasicLSTMCell源码:

def build(self, inputs_shape):
    if inputs_shape[-1] is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                       % inputs_shape)
    input_depth = inputs_shape[-1]
    h_depth = self._num_units
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + h_depth, 4 * self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))

build函数中初始化了[input_depth + h_depth, 4 * self._num_units]形状的变量,

输入:其中input_depth代表Xt输入的维度,h_depth也就是_num_units代表ht-1的维度;
输出:4*self._num_units为4个小黄块的维度W

并且源码并没有定义任何时间步长有关的参数,说明cell参数在不同time_step都是共享的。

part2

知道了tf.nn.rnn_cell.BasicLSTMCell是个什么东西之后,我们来讲讲:

搭建LSTM

创建cell之后有至少两种方式创建rnn

  • tf.nn.dynamic_rnn
batch_size = 5
time_step = 7
depth = 30
num_units = 20
inputs = tf.Variable(tf.random_normal([batch_size, time_step, depth])) 
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
outputs, output_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值