深度学习算法——LSTM(长短期记忆网络)

参考教材:《动手学习深度学习》

一、模型概述

        长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。解决这一问题的最早方法之一是长短期存储器(long short‐term memory,LSTM)(Hochreiter and Schmidhuber, 1997)。它有许多与门控循环单元一样的属性。有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些,却比门控循环单元早诞生了近20年。

1、门控记忆元

        可以说,长短期记忆网络的设计灵感来自于计算机的逻辑门。长短期记忆网络引入了记忆元(memory cell), 或简称为单元(cell)。有些文献认为记忆元是隐状态的一种特殊类型,它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。为了控制记忆元,我们需要许多门。其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理,这种设计的动机与门控循环单元相同,能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。

输入门、忘记门和输出门

        就如在门控循环单元中一样,当前时间步的输入和前一个时间步的隐状态作为数据送入长短期记忆网络的门中,如图所示。它们由三个具有sigmoid激活函数的全连接层处理,以计算输入门、遗忘门和输出门的值。因此,这三个门的值都在(0, 1)的范围内。

         我们来细化一下长短期记忆网络的数学表达。假设有h个隐藏单元,批量大小为n,输入数为d。因此,输入为Xt ∈ R n×d,前一时间步的隐状态为Ht−1 ∈ R n×h。相应地,时间步t的门被定义如下:输入门是It ∈ R n×h, 遗忘门是Ft ∈ R n×h,输出门是Ot ∈ R n×h。

它们的计算方法如下: 

候选记忆元

        由于还没有指定各种门的操作,所以先介绍候选记忆元(candidate memory cell)C˜ t ∈ R n×h。它的计算与 上面描述的三个门的计算类似,但是使用tanh函数作为激活函数,函数的值范围为(−1, 1)。下面导出在时间 步t处的方程:  

 记忆元

        在门控循环单元中,有一种机制来控制输入和遗忘(或跳过)。类似地,在长短期记忆网络中,也有两个门用于这样的目的:输入门It控制采用多少来自C˜ t的新数据,而遗忘门Ft控制保留多少过去的记忆元Ct−1 ∈ R n×h的 内容。使用按元素乘法,得出:

        如果遗忘门始终为1且输入门始终为0,则过去的记忆元Ct−1 将随时间被保存并传递到当前时间步。引入这种设计是为了缓解梯度消失问题,并更好地捕获序列中的长距离依赖关系。 这样我们就得到了计算记忆元的流程图,如图。

隐状态

        最后,我们需要定义如何计算隐状态 Ht ∈ R n×h,这就是输出门发挥作用的地方。在长短期记忆网络中,它仅仅是记忆元的tanh的门控版本。这就确保了Ht的值始终在区间(−1, 1)内:

        只要输出门接近1,我们就能够有效地将所有记忆信息传递给预测部分,而对于输出门接近0,我们只保留记忆元内的所有信息,而不需要更新隐状态。 下图提供了数据流的图形化演示。  

二、模型实现 

初始化模型参数

接下来,我们需要定义和初始化模型参数。如前所述,超参数num_hiddens定义隐藏单元的数量。我们按照标准差0.01的高斯分布初始化权重,并将偏置项设为0。

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size
    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01
    def three():
        return (normal((num_inputs, num_hiddens)),
        normal((num_hiddens, num_hiddens)),
        torch.zeros(num_hiddens, device=device))
    W_xi, W_hi, b_i = three() # 输入门参数
    W_xf, W_hf, b_f = three() # 遗忘门参数
    W_xo, W_ho, b_o = three() # 输出门参数
    W_xc, W_hc, b_c = three() # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
    b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

定义模型

在初始化函数中,长短期记忆网络的隐状态需要返回一个额外的记忆元,单元的值为0,形状为(批量大小, 隐藏单元数)。因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
    torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样:提供三个门和一个额外的记忆元。请注意,只有隐状态才会传递到输出层,而记忆元Ct不直接参与输出计算。

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
    W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
    outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

训练和预测

让我们通过引入的RNNModelScratch类来训练一个长短期记忆网络。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

perplexity 1.3, 17736.0 tokens/sec on cuda:0

time traveller for so it will leong go it we melenot ir cove i s

traveller care be can so i ngrecpely as along the time dime.

长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型。多年来已经提出了其许多变体,例如,多层、残差连接、不同类型的正则化。然而,由于序列的长距离依赖性,训练长短期记忆网络和其他序列模型 (例如门控循环单元)的成本是相当高的。

三、LSTM在数学建模的应用

        LSTM(长短期记忆)模型是一种特殊的循环神经网络(RNN),在处理和预测时间序列数据方面表现出色。在数学建模比赛中,LSTM模型可以应用于多种问题,以下是一些主要的应用场景:

  1. 时间序列预测

    • 股票价格预测:利用历史股票价格数据,预测未来的股票走势。
    • 经济指标预测:比如GDP、通货膨胀率、失业率等宏观经济数据的预测。
    • 能源消耗预测:根据历史能源使用数据预测未来的能源需求,有助于能源管理和规划。
  2. 自然语言处理(NLP)

    • 文本分类:在比赛中,如果涉及到对大量文本的分类,如情感分析、新闻分类等,LSTM可以捕捉文本中的序列信息。
    • 文本生成:生成文章、诗歌等连续文本。
  3. 交通流量预测

    • 预测不同时间段、不同路段的车流量,对于交通管理和城市规划具有重要意义。
  4. 环境监测

    • 空气质量预测:利用历史和实时监测数据预测空气质量,为环境治理提供决策支持。
    • 水位预测:预测河流、湖泊的水位变化,对于防洪减灾具有重要作用。
  5. 金融市场分析

    • 汇率预测:预测不同货币对的汇率变动。
    • 市场趋势分析:分析市场趋势,预测市场动向。
  6. 生物信息学

    • 基因表达分析:分析时间序列的基因表达数据,理解基因如何响应外部刺激。
    • 蛋白质结构预测:通过学习氨基酸序列的时序关系预测蛋白质结构。

        LSTM模型可以解决的问题通常具有以下特点:

  • 序列依赖性:问题的数据点之间存在时间上的依赖关系。
  • 长期依赖问题:需要记忆的数据间隔可能很长。
  • 非线性和复杂性:问题涉及的数据可能具有复杂的非线性关系。

        在使用LSTM模型时,需要注意以下几点:

  • 数据预处理:时间序列数据通常需要被标准化或归一化,并且可能需要时间窗口的划分来创建输入序列。
  • 超参数调优:LSTM模型有许多超参数,如学习率、隐藏层大小、迭代次数等,需要通过实验来优化。
  • 过拟合问题:LSTM模型可能会出现过拟合,需要采用正则化技术或dropout等方法来减轻。
  • 17
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值