先上经典的LSTM结构
1、首先 tf.nn.rnn_cell.BasicLSTMCell(num_units=n)中的参数num_units指的是什么?
上图中一个浅绿色大框框起来的我们暂时叫一个LSTM_Cell,我们可以看到一个LSTM_Cell中有四个基本的神经网络Cell(即四个黄色的小框),每个小框可以说存放的是个向量,且四个框中向量长度相同,这个相同的向量长度便是num_units
2、LSTM中的参数到底有多少个?
这里假设输入向量(即xt)的长度是75,num_units=125
我们详细看张量在LSTM_Cell中是如何流动的
1)遗忘门(最左边的黄色框)
它接收的是上一时刻的隐藏状态和当前时刻的输入,经过矩阵得到另外一个向量才经过。
所以这个过程的参数个数是125*(125+75)+125
2)输入门(第二个黄色框)
这个过程决定保留输入的哪些信息,由两步构成,首先是sigma层决定我们要更新哪些值,接下来tanh层对输入做一次加工(有点归一化的味道),二者得到的结果相乘后去更新LSTM_Cell的状态。这里经过了相当于两次矩阵相乘再加偏置的运算,所以这个过程的参数个数是2*(125*(125+75)+125)
3)输出门(最后一个黄色框)
可以看到隐藏状态也有两部分组成。此过程的参数个数也是125*(125+75)+125
4)LSTM_Cell的状态更新
这一部分,没有新的参数
总之,整个过程的参数个数是:(125*(125+75)+125)*4
更一般地,若输入的长度=m,隐藏层的长度=n,则一个LSTM层的参数个数是
(n*(n+m)+n)*4