跟李沐写LSTM(详细注释)

本文介绍了LSTM(LongShort-TermMemory)循环神经网络如何通过门控机制处理长期依赖问题和梯度问题,以及在自然语言处理和语音识别中的广泛应用。详细讲解了LSTM的工作原理和在实际编程中的实现示例。
摘要由CSDN通过智能技术生成

LSTM(Long Short-Term Memory)是一种特殊的循环神经网络(RNN),它能够学习长期依赖关系。LSTM由Hochreiter和Schmidhuber在1997年提出,旨在解决传统RNN在处理长序列数据时遇到的梯度消失或梯度爆炸问题。

LSTM的核心思想

LSTM的核心在于其内部结构的设计,它引入了“门”(Gate)的概念,这些门可以控制信息的流动,包括信息的保存、更新和遗忘。LSTM的基本单元包括三个门:输入门(Input Gate)、遗忘门(Forget Gate)和输出门(Output Gate)。

  • 输入门:决定当前输入信息中有多少应该被添加到单元状态中。
  • 遗忘门:决定哪些信息应该从单元状态中被遗忘或保留。
  • 输出门:基于当前的单元状态和输入,决定最终的输出。

LSTM的工作流程

  1. 遗忘阶段:遗忘门查看上一个隐藏状态和当前输入,决定保留和丢弃哪些信息。
  2. 输入阶段:输入门决定新的输入信息中哪些部分应该被记住,并结合遗忘门的结果,更新单元状态。
  3. 输出阶段:输出门基于更新后的单元状态和当前输入,决定最终的输出,这个输出将作为下一个时间步的隐藏状态。

LSTM的优势

  • 长期依赖问题:LSTM通过精心设计的门控机制,能够有效地捕捉长期依赖关系,这是传统RNN难以做到的。
  • 梯度问题:由于其结构,LSTM在训练过程中不容易出现梯度消失或梯度爆炸的问题,使得网络可以学习到更深层次的模式。
  • 灵活性:LSTM可以应用于多种序列数据任务,如语言模型、机器翻译、语音识别等。

LSTM的应用

LSTM在自然语言处理(NLP)、语音识别、时间序列分析等领域有着广泛的应用。例如,在NLP中,LSTM可以用来建模句子中的词序关系,进行文本生成、情感分析、命名实体识别等任务;在语音识别中,LSTM能够处理音频信号的时间序列特性,提高识别的准确性。

结论

LSTM作为一种强大的序列模型,通过其独特的门控机制解决了传统RNN在处理长序列数据时的难题。它的出现极大地推动了序列建模技术的发展,并在多个领域中取得了显著的成果。随着深度学习技术的不断进步,LSTM仍然是一个非常有价值和有潜力的研究领域。

import torch
from torch import nn
from d2l import torch as d2l


class LSTMScratch(d2l.Module):
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        #              将其封装为nn.Parameter对象。用于初始化权重矩阵。
        init_weight = lambda *shape: nn.Parameter(torch.randn(*shape) * sigma)
        #         - 一个形状为(num_inputs, num_hiddens)的权重矩阵(输入门或输入节点的输入权重)
        #         - 一个形状为(num_hiddens, num_hiddens)的权重矩阵(输入门、忘记门、输出门的隐藏状态权重)
        #         - 一个形状为(num_hiddens,)的全零偏置向量(输入门、忘记门、输出门的偏置项)
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          nn.Parameter(torch.zeros(num_hiddens)))
        # 初始化LSTM单元的权重和偏置参数
        # Input gate
        self.W_xi, self.W_hi, self.b_i = triple()
        # W_xi: 输入门的输入权重矩阵
        # W_hi: 输入门的隐藏状态权重矩阵
        # b_i: 输入门的偏置项
        # Forget gate
        self.W_xf, self.W_hf, self.b_f = triple()
        # W_xf: 忘记门的输入权重矩阵
        # W_hf: 忘记门的隐藏状态权重矩阵
        # b_f: 忘记门的偏置项
        # Output gate
        self.W_xo, self.W_ho, self.b_o = triple()
        # W_xo: 输出门的输入权重矩阵
        # W_ho: 输出门的隐藏状态权重矩阵
        # b_o: 输出门的偏置项

        # Input node (Cell gate or Candidate cell state)
        self.W_xc, self.W_hc, self.b_c = triple()
        # W_xc: 输入节点的输入权重矩阵
        # W_hc: 输入节点的隐藏状态权重矩阵
        # b_c: 输入节点的偏置项


@d2l.add_to_class(LSTMScratch)
def forward(self, inputs, H_C=None):
    if H_C is None:
        # Initial state with shape: (batch_size, num_hiddens)
        H = torch.zeros((inputs.shape[1], self.num_hiddens),
                        device=inputs.device)
        C = torch.zeros((inputs.shape[1], self.num_hiddens),
                        device=inputs.device)
    else:
        H, C = H_C
    outputs = []
    for X in inputs:
        # 输入门I用于控制新记忆单元C_tilde的多少被加入到记忆单元C中。
        # 遗忘门F用于控制上一个时间步的记忆单元C中有多少被遗忘。
        # 输出门O用于控制当前时间步的记忆单元C多少被输出为隐藏状态H。
        I = torch.sigmoid(torch.matmul(X, self.W_xi) +
                          torch.matmul(H, self.W_hi) + self.b_i)
        F = torch.sigmoid(torch.matmul(X, self.W_xf) +
                          torch.matmul(H, self.W_hf) + self.b_f)
        O = torch.sigmoid(torch.matmul(X, self.W_xo) +
                          torch.matmul(H, self.W_ho) + self.b_o)
        C_tilde = torch.tanh(torch.matmul(X, self.W_xc) +
                             torch.matmul(H, self.W_hc) + self.b_c)
        C = F * C + I * C_tilde
        H = O * torch.tanh(C)
        outputs.append(H)
    return outputs, (H, C)


data = d2l.TimeMachine(batch_size=1024, num_steps=32)
lstm = LSTMScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(lstm, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
model.predict('it has', 20, data.vocab, d2l.try_gpu())

  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值