1.LSTM的结构(本文啥也没讲)
LSTM的介绍就不多讲了,直接附上链接:
LSTM网络结构
https://www.cnblogs.com/mfryf/p/7904017.html 中文版本
http://colah.github.io/posts/2015-08-Understanding-LSTMs/ 英文版本
2.LSTM学中个人思考过的问题
(1)ht和Ct维度相同?
维度是相同的,因为h(t)=o(t)⊙tanh(C(t)),两者点乘所以维度必然相同,且维度由tensorflow代码BasicLSTMCell(num_units = ??)中的num_units参数确定。
(2)因此Wf、Wi、Wc、Wo参数维度相同?
这几个矩阵的维度是相同的,考虑到h与x的合并后。以上矩阵的参数维度为(num_units,input_depth + h_depth)
(3)不定长序列中,往往通过最后时刻的输出ht决定分类,定长序列可以通过所有的h1至ht构造全连接层然后通过softmax分类?
不定长序列是不能够构造全连接所有h状态的,因为序列的长度不定。(其实不定长序列也可以“全连接”
h
h
h状态,只是不是所有的
h
t
h_t
ht状态直接与输出乘以一个大矩阵连接,而是所有不同时刻的
h
t
h_t
ht均共享权值矩阵)
(4)所有不同时刻的Wf、Wi、Wc、Wo参数共享?
共享,否则不定长序列中无法存储以上参数。
(5)正向传播的表达式?
通过代码验证,第一个时刻的序列是与tensorflow的生成的ht和Ct结果相同,多个时刻则不同???查看源码《rnn_cell_impl.py》发现tensorflow在遗忘门时有个小的偏差叠加,即需要叠加遗忘数值forget_bias(为一个实数)ft = sigmoid( xh.dot(Wf) + Bf+forget_bias),这样在前向推导过程中,所有序列时刻的参数都已经和tensorflow生成的状态c、h相等。
xh = np.column_stack((x,h_pre))
ft = sigmoid(xh.dot(Wf) + Bf+forget_bias)
it = sigmoid(xh.dot(Wi)+Bi)
ot = sigmoid(xh.dot(Wo)+Bo)
ct_ = tanh(xh.dot(Wc)+Bc)
ct = np.multiply(ft,c_pre)+np.multiply(it,ct_)
ht = np.multiply(ot,tanh(ct))
获得LSTM内部各矩阵以及偏置参数以及LSTM网络状态C和h输出详见链接:
https://blog.csdn.net/koibiki/article/details/83116596
https://blog.csdn.net/qq_35203425/article/details/79572514
https://blog.csdn.net/zhylhy520/article/details/82631736
(6).由问题3从而可以构造损失函数,反向传播推导?
显然若仅有最后时刻的ht输出构造的损失函数和由所有时刻构造的损失函数在反向传播的推导过程中各参数的表达式是不相同的。比如当为所有时刻构造的全连接softmax作为损失函数时,ht的偏导数由两部分构成,一部分是全连接,另一部分是ht+1的偏导数向ht做偏导数。所以不同的损失函数对于反向传播的参数推导是决定性(看了好几个网页似乎没有把损失函数交代清楚直接推导的反向传播,因为softmax有个很重要的结论∂L/∂z = (y^(t)−y(t))。
3.反向传播part1(公式轰炸)
本文所有向量全为列向量
当然想所有公式能够从源头就开始造,但是确实每个细节都给出推导过程太费时间了,比如公式(1)结论很简单,但实际上是
L
L
L先对
a
a
a求偏导然后再对
z
z
z求偏导,所以如果你有疑惑可能需要针对某个特定的公式进行查询。
以最后时刻的
h
τ
h_{\tau~}
hτ 构造softmax函数。
3.1softmax层反向传播
3.2 知识点hadamard积的微分
3.3 LSTM结构图
如开篇所示可以通过连接了解到具体的结构,总结一句话就是你中有我我中有你,然后对于反向传播的推导来说就是个不小的挑战。
我将结构图转换为如下所示(将“用于极简”的说明反向传播),实际上当用于反向传播时,所有的箭头的方向都是需要反过来看。
3.4最后时刻LSTM单元内的 h , c h,c h,c反向传播
又来一叠小菜(单元内的非线性函数的导数):
最后时刻
τ
~\tau~
τ 相关的变量仅有
h
τ
h_{\tau~}
hτ ,则
∂
L
/
∂
h
τ
\partial L/\partial h_{\tau~}
∂L/∂hτ 即公式(3)
然后求
∂
L
/
∂
c
τ
\partial L/\partial c_{\tau~}
∂L/∂cτ :
3.4节到底干了啥?我们把他们丢到图里面去,心情会稍微好一点:
就是求得了图中(3)式、(9)式。那(10)~(14)又在干嘛呢?且继续看下去
3.5 求得递归时刻的反向表达式
3.5.1先求对 h t − 1 h_{t-1} ht−1的偏导数
不绕道
c
t
c_t
ct的表达式
∂
h
t
/
→
∂
h
t
−
1
\partial h_t \stackrel{\rightarrow}/ \partial h_{t-1}
∂ht/→∂ht−1 ,相当于仅从
o
t
o_t
ot路径走,其表达式为(10),顺便把(11)(12)一起办了。
这3个公式其实展示的正是下图结构的反向传递。
其实(10)(12)式是可以合并的,合并后的表达式(10_12):
3.5.2再求对 c t − 1 c_{t-1} ct−1的偏导数
对
c
t
−
1
c_{t-1}
ct−1的反向传播有两条路径,路径1是直接从
c
t
c_{t}
ct过来的
∂
L
/
→
∂
c
t
−
1
\partial L \stackrel{\rightarrow}/ \partial c_{t-1}
∂L/→∂ct−1 ,以及从
h
t
−
1
h_{t-1}
ht−1过来的路径2:
∂
L
/
⇒
∂
c
t
−
1
\partial L \stackrel{\Rightarrow~}/ \partial c_{t-1}
∂L/⇒ ∂ct−1
3.5.2节就是解决下图中的方向传播:
以上部分的反向传播总结
即使拆开看,也还是并不简单,再回过头来理一理,传播图片再次杀出。
从损失函数开始倒带,先通过表达式(3)求得对
h
τ
h_{ \tau~}
hτ 的偏导数,然后又根据(9)式求得对
c
τ
c_{ \tau}
cτ的偏导数。而这解决了最后时刻
τ
\tau
τ的变量的偏导数
∂
L
/
∂
h
τ
\partial L/\partial h_{\tau~}
∂L/∂hτ 、
∂
L
/
∂
c
τ
\partial L/\partial c_{\tau~}
∂L/∂cτ :,顺便把当前时刻
τ
\tau
τ中对
V
V
V、
b
b
b的偏导数求出来。
然后反向传播递归,现以倒数第二个时刻
τ
−
1
{\tau-1}
τ−1来具体说明如何使用反向传递。
已知条件
∂
L
/
∂
h
τ
\partial L/\partial h_{\tau~}
∂L/∂hτ 、
∂
L
/
∂
c
τ
\partial L/\partial c_{\tau~}
∂L/∂cτ 。根据(10_12)式可以求得
∂
L
/
∂
h
t
−
1
\partial L / \partial h_{t-1}
∂L/∂ht−1,仔细观察(10_12)式确实只含有变量
∂
L
/
∂
h
t
\partial L/\partial h_{t~}
∂L/∂ht 、
∂
L
/
∂
c
t
\partial L/\partial c_{t~}
∂L/∂ct 。
然后求(13_14)式求得
∂
L
/
∂
c
t
−
1
\partial L / \partial c_{t-1}
∂L/∂ct−1 ,同样表达式(13_14)中仅含有已知的
∂
L
/
∂
c
t
\partial L/\partial c_{t~}
∂L/∂ct 以及
∂
L
/
∂
h
t
−
1
\partial L / \partial h_{t-1}
∂L/∂ht−1。
再多看一眼,其实(9)(11)(14)式仅仅是时刻不同,但是偏导数的形式是一样的。
或许应该思考一下为什么要单独将最后一个时刻的
∂
L
/
∂
h
t
\partial L/\partial h_{t~}
∂L/∂ht 、
∂
L
/
∂
c
t
\partial L/\partial c_{t~}
∂L/∂ct 单独罗列出来,关于最后时刻对
h
τ
h_{\tau~}
hτ 的偏导数和往前时刻的
∂
L
/
∂
h
t
\partial L/\partial h_{t~}
∂L/∂ht 的不同是显而易见的。对于
∂
L
/
∂
c
τ
\partial L/\partial c_{\tau~}
∂L/∂cτ ,由(9)式直接可得,但是往前时刻的
∂
L
/
∂
c
t
\partial L/\partial c_{t~}
∂L/∂ct ,由两条路径组成,分别是式(13)(14),所以不能像最后时刻
τ
\tau
τ那样可以仅由(14)式得到偏导数
∂
L
/
∂
c
t
\partial L/\partial c_{t~}
∂L/∂ct 。
后续
代码链接:LSTM反向传播代码实现(通过tensorflow和自编写代码实现)
这还仅仅是时刻递归,还没有真正得到您所想要的参数矩阵。后面再补……
有需要文中所有公式打印方便阅读的话,请发邮件494431025@qq.com