LSTM:长短期记忆网络

理论

在这里插入图片描述
在这里插入图片描述

LSTM的核心是细胞状态,用贯穿细胞的水平线表示
在这里插入图片描述

1.计算遗忘门

在这里插入图片描述
决定细胞状态需要舍弃哪部分无用信息

f t = σ g ( W f x t + U f h t − 1 + b f ) f_t = \sigma{_g} (W_f x_t+U_f h_{t-1}+b_f) ft=σg(Wfxt+Ufht1+bf)

2.计算输入门

在这里插入图片描述
决定细胞状态需要添加哪些有用信息

i t = σ g ( W i x t + U i h t − 1 + b i ) i_t = \sigma{_g}(W_i x_t+U_i h_{t-1}+b_i) it=σg(Wixt+Uiht1+bi)

3.计算候选细胞状态

c ~ t = σ c ( W c x t + U c h t − 1 + b c ) \widetilde{c}_t=\sigma{_c}(W_cx_t+U_ch_{t-1}+b_c) c t=σc(Wcxt+Ucht1+bc)

4.更新细胞状态

在这里插入图片描述

c t = f t ∘ c t − 1 + i t ∘ c ~ t c_t=f_t \circ c_{t-1}+i_t \circ \widetilde{c}_t ct=ftct1+itc t

5.计算输出门

控制细胞状态中哪些信息被输出
o t = σ g ( W o x t + U o h t − 1 + b o ) o_t=\sigma{_g}(W_ox_t+U_oh_{t-1}+b_o) ot=σg(Woxt+Uoht1+bo)

6.计算输出隐状态

h t = o t ∘ σ h ( c t ) h_t = o_t \circ \sigma{_h}(c_t) ht=otσh(ct)

实践

从零实现LSTM

class My_LSTM(nn. Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.gates = nn.Linear(input_size + hidden_size, hidden_size * 4)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn. Tanh()
        self.output = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, output_size)
        )
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, x):
        batch_size = x.size(0)
        seq_len = x.size(1)
        h, c = (torch.zeros(batch_size, self.hidden_size).to(x.device) for _ in range(2))
        y_list = []
        for i in range(seq_len):
            forget_gate, input_gate, output_gate, candidate_cell = \
                self.gates(torch.cat([x[:, i, :], h], dim=-1)).chunk(4, -1)
            forget_gate, input_gate, output_gate = (self.sigmoid(g)
                                                    for g in (forget_gate, input_gate, output_gate))
            c = forget_gate * c + input_gate * self.tanh(candidate_cell)
            h = output_gate * self.tanh(c)
            y_list.append(self.output(h))
        return torch.stack(y_list, dim=1), (h, c)

Pytorch实现LSTM

参数

在这里插入图片描述

输入

在这里插入图片描述

输出

在这里插入图片描述

lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=1,batch_first=True).to(device)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值