I recently read through the source code of the LSTM in PyTorch and found a lot in it worth learning.
First, let's look at how PyTorch defines the LSTM:
\begin{array}{ll}
i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
h_t = o_t \odot \tanh(c_t) \\
\end{array}
The corresponding formulas:
Circle 1: $f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf})$
Circle 2: $i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi})$
Circle 3: $g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg})$
Circle 4: $o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho})$
Circle 5: $c_t = f_t \odot c_{t-1} + i_t \odot g_t$
Circle 6: $h_t = o_t \odot \tanh(c_t)$
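
To make the six steps concrete, here is a minimal sketch of a single LSTM time step written directly from the formulas above. It uses only the Python standard library so that every multiplication and addition is visible. This is not PyTorch's implementation: the names `lstm_step` and `affine` and the layout of the `params` dict (one separate matrix and bias per gate) are assumptions made for illustration only.

```python
import math


def sigmoid(x):
    """Logistic function, the sigma in the gate formulas."""
    return 1.0 / (1.0 + math.exp(-x))


def affine(W, x, b):
    """Return W x + b, with W given as a list of rows and x, b as lists."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


def lstm_step(x_t, h_prev, c_prev, params):
    """Compute (h_t, c_t) from x_t, h_{t-1}, c_{t-1}.

    `params` is a hypothetical dict holding W_ii, b_ii, W_hi, b_hi, ...
    under exactly those names, one entry per symbol in the formulas.
    """
    def gate(name, activation):
        # activation(W_i* x_t + b_i* + W_h* h_{t-1} + b_h*)
        from_input = affine(params["W_i" + name], x_t, params["b_i" + name])
        from_hidden = affine(params["W_h" + name], h_prev, params["b_h" + name])
        return [activation(u + v) for u, v in zip(from_input, from_hidden)]

    f_t = gate("f", sigmoid)    # Circle 1: forget gate
    i_t = gate("i", sigmoid)    # Circle 2: input gate
    g_t = gate("g", math.tanh)  # Circle 3: candidate cell values
    o_t = gate("o", sigmoid)    # Circle 4: output gate
    # Circle 5: element-wise update of the cell state
    c_t = [f * c + i * g for f, c, i, g in zip(f_t, c_prev, i_t, g_t)]
    # Circle 6: element-wise computation of the hidden state
    h_t = [o * math.tanh(c) for o, c in zip(o_t, c_t)]
    return h_t, c_t
```

A quick sanity check: if every weight and bias is zero, each sigmoid gate evaluates to 0.5 and $g_t$ to 0, so $c_t = 0.5 \, c_{t-1}$ and $h_t = 0.5 \tanh(c_t)$ regardless of the input. Note that PyTorch does not keep eight separate matrices as this sketch does; according to its documentation, the four input-side matrices are stacked into a single `weight_ih` parameter of shape `(4*hidden_size, input_size)` in the order i, f, g, o, and likewise `weight_hh` for the hidden-side matrices.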