深度学习第40讲:RNN之长短期记忆网络LSTM

     在上一讲中笔者介绍了 RNN 变体之一的门控循环单元 GRU,本节将继续在 GRU 基础之上介绍一种更为通用的记忆网络、也就是大名鼎鼎的 LSTM(Long Short Term Memory)。

640?wx_fmt=jpeg

     我们先来回顾一下上一讲的 GRU 单元。GRU 中有两个门:一个是更新门,一个是重置门:

640?wx_fmt=png

     在 GRU 中我们通过更新门来决定是否要用记忆细胞候选值 ct_hat 来更新 ct。那么 LSTM 相较于 GRU 的变化在哪呢?LSTM 相较于 GRU 最大的特点就是更加复杂了,复杂到一个 LSTM 单元中有三个门控,但是复杂的结果就是相比于 GRU 有着更好的记忆性能。下图是标准 RNN 结构与 LSTM 单元的结构对比:

640?wx_fmt=png

640?wx_fmt=png

     可以看到,每个 LSTM 单元中包含了 4 个交互的网络层,完整的 LSTM 公式表示如下所示:

640?wx_fmt=png

     下面我们根据结构图和公式来逐模块解释 LSTM。

  • 记忆细胞 ct-1 -> ct

     从图中可以看到在 LSTM 单元的最上层有一条贯穿的关于记忆细胞 ct-1 到 ct 的箭头直线。这样贯穿的直线表现记忆信息在网络各层之间保持下去很容易。

640?wx_fmt=png

      然后来看 LSTM 的第一个门控:遗忘门。

  • 遗忘门(forget gate)

     遗忘门的计算公式如下:

640?wx_fmt=png

     所谓遗忘门就是我们要决定从记忆细胞中是否丢弃某些信息,这个过程就是遗忘门要干的事,我们通过一个 sigmoid 函数来进行处理。遗忘门在整个结构中的位置如下图所示:

640?wx_fmt=png

     可以看到,遗忘门接受来自输入 xt 和上一层隐状态 ht-1 的值进行加权计算处理。

  • 更新门(update gate)

     然后就是更新门,我们需要确定什么样的信息能存入细胞状态中。这跟我们在 GRU 中类似,除了计算更新门之外,还需要通过 tanh 计算记忆细胞的候选值 ct_hat。 LSTM 中更新门需要更加细心一点。候选值和更新门的计算公式如下:

640?wx_fmt=png

     更新门在整个结构中的位置如下图所示:

640?wx_fmt=png

     然后,LSTM 结合遗忘门、更新门、上一层记忆细胞值和记忆细胞候选值来共同决定和更新当前细胞状态:

640?wx_fmt=png

     在整个结构中位置如下图所示:

640?wx_fmt=png

  • 输出门(output gate)

     LSTM 提供了单独的输出门。计算公式如下:

640?wx_fmt=png

     输出门的位置如下图所示:

640?wx_fmt=png

     

     以上便是完整的 LSTM 结构。虽然复杂,但经我们逐步解析之后也就基本清晰了。下面我们将上述过程进行简单的实现,写一个最简单的 LSTM 单元:

def rnn_forward(x, a0, parameters):
    """
    Arguments:
    x -- Input data for every time-step, of shape (n_x, m, T_x).
    a0 -- Initial hidden state, of shape (n_a, m)
    parameters -- python dictionary containing:
                        Waa -- Weight matrix multiplying the hidden state, numpy array of shape (n_a, n_a)
                        Wax -- Weight matrix multiplying the input, numpy array of shape (n_a, n_x)
                        Wya -- Weight matrix relating the hidden-state to the output, numpy array of shape (n_y, n_a)
                        ba --  Bias numpy array of shape (n_a, 1)
                        by -- Bias relating the hidden-state to the output, numpy array of shape (n_y, 1)

    Returns:
    a -- Hidden states for every time-step, numpy array of shape (n_a, m, T_x)
    y_pred -- Predictions for every time-step, numpy array of shape (n_y, m, T_x)
    caches -- tuple of values needed for the backward pass, contains (list of caches, x)
    """

    # Initialize "caches" which will contain the list of all caches
    caches = []    
    # Retrieve dimensions from shapes of x and parameters["Wya"]
    n_x, m, T_x = x.shape
    n_y, n_a = parameters["Wya"].shape    
    # initialize "a" and "y" with zeros
    a = np.zeros((n_a, m, T_x))
    y_pred = np.zeros((n_y, m, T_x))    
    # Initialize a_next
    a_next = a0    
    # loop over all time-steps
    for t in range(T_x):        
        # Update next hidden state, compute the prediction, get the cache
        a_next, yt_pred, cache = rnn_cell_forward(x[:,:,t], a_next, parameters)        
        # Save the value of the new "next" hidden state in a
        a[:,:,t] = a_next        
        # Save the value of the prediction in y
        y_pred[:,:,t] = yt_pred        
        # Append "cache" to "caches" 
        caches.append(cache)    
    # store values needed for backward propagation in cache
    caches = (caches, x)    
    return a, y_pred, caches

     计算示例如下:

640?wx_fmt=png

     以上便是本讲内容,接下来我们将继续深入RNN、GRU 和 LSTM 在自然语言处理等方面的应用学习。

参考资料:

deeplearningai.com

https://blog.csdn.net/qq_28743951/article/details/78974058

往期精彩:


一个数据科学从业者的学习历程

640?

640?wx_fmt=jpeg

长按二维码.关注机器学习实验室

640?wx_fmt=jpeg

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值