仅供本人参考,错了概不负责
part1
图源:https://www.zhihu.com/question/41949741/answer/309529532
我们在使用tf.nn.rnn_cell.BasicLSTMCell
时,有一个要自己设置的参数 num_units
,先讲讲这玩意是啥?
这四个小黄块,有一定了解的同学都知道[ht-1, Xt]
输入后,经过四个黄块和St-1
,又得到了ht
和St
,所以必然[ht-1, Xt]
经过黄块后,维度和原ht-1
一样。这个num_units
就是ht-1
维度,那几个小黄块就是线性映射后再结激活函数。
查看BasicLSTMCell
源码:
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
% inputs_shape)
input_depth = inputs_shape[-1]
h_depth = self._num_units
self._kernel = self.add_variable(
_WEIGHTS_VARIABLE_NAME,
shape=[input_depth + h_depth, 4 * self._num_units])
self._bias = self.add_variable(
_BIAS_VARIABLE_NAME,
shape=[4 * self._num_units],
initializer=init_ops.zeros_initializer(dtype=self.dtype))
build函数中初始化了[input_depth + h_depth, 4 * self._num_units]
形状的变量,
输入:其中
input_depth
代表Xt
输入的维度,h_depth
也就是_num_units
代表ht-1
的维度;
输出:4*self._num_units
为4个小黄块的维度W
并且源码并没有定义任何时间步长有关的参数,说明cell
参数在不同time_step都是共享的。
part2
知道了tf.nn.rnn_cell.BasicLSTMCell
是个什么东西之后,我们来讲讲:
搭建LSTM
创建cell
之后有至少两种方式创建rnn
tf.nn.dynamic_rnn
batch_size = 5
time_step = 7
depth = 30
num_units = 20
inputs = tf.Variable(tf.random_normal([batch_size, time_step, depth]))
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
outputs, output_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf