pytorch lstm源代码解读

最近阅读了pytorch中lstm的源代码,发现其中有很多值得学习的地方。
首先查看pytorch当中相应的定义

        \begin{array}{ll} \\
            i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
            f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
            o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
            c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
            h_t = o_t \odot \tanh(c_t) \\
        \end{array}

lstm原理图
对应公式:
圈1: f t = σ ( W i f x t + b i f + W h f h t − 1 + b h f ) f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) ft=σ(Wifxt+bif+Whfht1+bhf)
圈2: i t = σ ( W i i x t + b i i + W h i h t − 1 + b h i ) i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) it=σ(Wiixt+bii+Whiht1+bhi)
圈3: g t = tanh ⁡ ( W i g x t + b i g + W h g h t − 1 + b h g ) g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) gt=tanh(Wigxt+big+Whght1+bhg)
圈4: o t = σ ( W i o x t + b i o + W h o h t − 1 + b h o ) o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) ot=σ(Wioxt+bio+Whoht1+bho)
圈5: c t = f t ⊙ c t − 1 + i t ⊙ g t c_t = f_t \odot c_{t-1} + i_t \odot g_t ct=ftct1+itgt
圈6: h t = o t ⊙ tanh ⁡ ( c t ) h_t = o_t \odot \tanh(c_t)

评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值