LSTM网络初识

LSTM网络

介绍

Long Short Term Memory networks(以下简称LSTMs),一种特殊的RNN网络,该网络设计出来是为了解决长依赖问题。该网络由 Hochreiter & Schmidhuber (1997)引入,并有许多人对其进行了改进和普及。他们的工作被用来解决了各种各样的问题,直到目前还被广泛应用。

所有循环神经网络都具有神经网络的重复模块链的形式。 在标准的RNN中,该重复模块将具有非常简单的结构,例如单个tanh层。标准的RNN网络如下图所示

在这里插入图片描述
LSTMs也具有这种链式结构,但是它的重复单元不同于标准RNN网络里的单元只有一个网络层,它的内部有四个网络层。LSTMs的结构如下图所示。

在这里插入图片描述

三道门

LSTM有三道门来控制细胞状态,分别为忘记门、输入门和输出门。
LSTM第一步是决定需要丢弃那些信息,而通过忘记门里的sigmoid单元来处理。他通过查看上一步的h[t-1]和这一步的x[t]来输出一个0-1的向量,该向量表示细胞状态c[t-1]中那些信息保留或丢弃多少。忘记门如下图所示。
在这里插入图片描述
下一步是决定跟细胞状态添加那些信息,又分两个步骤,首先,利用h[t-1]和x[t]通过输出门的操作决定更新那些信息,然后利用h[t-1]和x[t]通过一个tanh层得到一个新的候选信息,这些信息可能会被更新到细胞信息中。
在这里插入图片描述
下面将更新旧的细胞信息C[t-1]变为新的细胞信息C[t]。更新的规则是通过忘记门选择记住旧细胞的一部分,通过输入门选择添加的候选细胞信息的一部分得到细胞信息C[t]
在这里插入图片描述
更新完成后细胞状态后需要根据输入的h[t-1]和x[t]来判断输出细胞的那些状态特征,这里需要经过一个成为输入门的sigmoid层得到判断条件,然后将细胞状态经过tanh层得到一个【-1,1】之间的向量,该向量与输出门得到的判断条件相乘就得到了最终RNN单元的输出。
在这里插入图片描述

代码实现
class Rnn(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_layer, n_classes):
        super(Rnn, self).__init__()
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        out, (h_n, c_n) = self.lstm(x)
        # 此时可以从out中获得最终输出的状态h
        # x = out[:, -1, :]
        x = h_n[-1, :, :]
        x = self.classifier(x)
        return x
LSTM表示

lstm的block相当于rnn的隐层,不过这个隐层叫block或者cell之类的。

输入是x,c[t-1], h[t-1], 输出是cell_output = h, state = (c[t],h[t])。注意cell_output包含所有步的h,而state中的h^t只包含最近一步的h。另外c和h的维度应该都是hidden_size(unit_size)。

看到有多种不同等价的图表示:
在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值