torch学习 (三十二):周杰伦歌词数据集与长短期记忆 (LSTM)

1 引入

  本文介绍一种常用的门控循环神经网络:长短期记忆 (long short-term memory, LSTM)。它比门控循环单元的结构稍微复杂一点。

2 长短期记忆

  LSTM引入了 3 3 3个门,即输入门 (input gate)、遗忘门 (forget gate)和输出门 (output gate),以及与隐藏状态形状相同的记忆细胞,从而记录额外的信息。

2.1 输入门、遗忘门和输出门

  与门控循环单元中的重置门和更新门一样,如下图,长短期记忆的门的输入均为当前时间步输入 X t \boldsymbol{X}_t Xt与上一时间步隐藏状态 H t − 1 \boldsymbol{H}_{t-1} Ht1,输出由激活函数为sigmoid函数的全连接层计算得到。如此一来,这 3 3 3个门元素的值域均为 [ 0 , 1 ] [0,1] [0,1]

  具体来说,假设隐藏单元个数为 h h h,给定时间步 t t t的小批量输入 X t ∈ R n × d \boldsymbol{X}_t\in\mathbb{R}^{n\times d} XtRn×d和上一时间步隐藏状态 H t − 1 ∈ R n × h \boldsymbol{H}_{t-1}\in\mathbb{R}^{n \times h} Ht1Rn×h。时间步 t t t的输入门 I t ∈ R n × h \boldsymbol{I}_t\in\mathbb{R}^{n\times h} ItRn×h、遗忘门 F t ∈ R n × h \boldsymbol{F}_t\in\mathbb{R}^{n\times h} FtRn×h和输出门 O t ∈ R n × h \boldsymbol{O}_t\in\mathbb{R}^{n\times h} OtRn×h分别计算如下:
I t = σ ( X t W x i + H t − 1 W h i + b i ) , F t = σ ( X t W x f + H t − 1 W h f + b f ) , O t = σ ( X t W x o + H t − 1 W h o + b o ) , \begin{aligned} \boldsymbol{I}_{t} &=\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x i}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h i}+\boldsymbol{b}_{i}\right), \\ \boldsymbol{F}_{t} &=\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x f}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h f}+\boldsymbol{b}_{f}\right), \\ \boldsymbol{O}_{t} &=\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x o}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h o}+\boldsymbol{b}_{o}\right), \end{aligned} ItFtOt=σ(XtWxi+Ht1Whi+bi),=σ(XtWxf+Ht1Whf+bf),=σ(XtWxo+Ht1Who+bo),其中 W x i , W x f , W x o ∈ R d × h \boldsymbol{W}_{xi}, \boldsymbol{W}_{xf}, \boldsymbol{W}_{xo} \in \mathbb{R}^{d\times h} Wxi,Wxf,WxoRd×h W h i , W h f , W h o ∈ R h × h \boldsymbol{W}_{hi}, \boldsymbol{W}_{hf}, \boldsymbol{W}_{ho} \in \mathbb{R}^{h\times h} Whi,Whf,WhoRh×h是权重参数, b i , b f , b o ∈ R h × h \boldsymbol{b}_{i}, \boldsymbol{b}_{f}, \boldsymbol{b}_{o} \in \mathbb{R}^{h\times h} bi,bf,boRh×h是偏差参数。

2.2 候选记忆细胞

  长短期记忆需要计算候选记忆细胞 C ~ t \tilde{\boldsymbol{C}}_t C~t。它的计算与上面介绍的 3 3 3个门类似,但使用了值域在 [ − 1 , 1 ] [-1,1] [1,1]的tanh函数作为激活函数,如下图所示。

  具体来说,时间步 t t t的候选记忆细胞 C ~ t ∈ R n × h \tilde{\boldsymbol{C}}_t\in\mathbb{R}^{n\times h} C~tRn×h的计算为:
C ~ t = tanh ( X t W x c + H t − 1 W h c + b c ) , \tilde{\boldsymbol{C}}_t = \text{tanh}(\boldsymbol{X}_t\boldsymbol{W}_{xc}+\boldsymbol{H}_{t-1}\boldsymbol{W}_{hc}+\boldsymbol{b}_c), C~t=tanh(XtWxc+Ht1Whc+bc),其中 W x c ∈ R d × h \boldsymbol{W}_{xc}\in\mathbb{R}^{d\times h} WxcRd×h W h c ∈ R h × h \boldsymbol{W}_{hc}\in\mathbb{R}^{h\times h} WhcRh×h b c ∈ R 1 × h \boldsymbol{b}_c\in\mathbb{R}^{1\times h} bcR1×h是偏差参数。

2.3 记忆细胞

  通过元素值域在 [ 0 , 1 ] [0,1] [0,1]的输入门、遗忘门和输出门来控制隐藏状态中信息的流动,这一般也是通过使用按元素乘法 ⊙ \odot 来实现的。当前时间步记忆细胞 C t ∈ R n × h \boldsymbol{C}_t\in\mathbb{R}^{n \times h} CtRn×h的计算组合了上一时间步记忆细胞和当前时间步候选记忆细胞的信息,并通过遗忘门和输入门来控制信息的流动:
C t = F t ⊙ C t − 1 + I t ⊙ C ~ t , \boldsymbol{C}_t = \boldsymbol{F}_t \odot \boldsymbol{C}_{t-1}+\boldsymbol{I}_t\odot\tilde{\boldsymbol{C}}_t, Ct=FtCt1+ItC~t,如下图所示。该设计可以应对RNN中的梯度衰减问题,并更好地捕捉时间序列中时间步间距较大依赖关系

2.4 隐藏状态

  有了记忆细胞以后,接下来可以通过输出门来控制从记忆细胞到隐藏状态 H t \boldsymbol{H}_t Ht的信息流动:
H t = O t ⊙ tanh ( C t ) . \boldsymbol{H}_t=\boldsymbol{O}_t\odot\text{tanh}(\boldsymbol{C}_t). Ht=Ottanh(Ct).这里的tanh函数确保隐藏状态元素值在 [ − 1 , 1 ] [-1,1] [1,1]之间。需要注意的是,当输出门近似 1 1 1时,记忆细胞信息将传递到隐藏状态供输出层使用;解决 0 0 0时,记忆细胞的信息只自己保留,如下图。

3 代码

  代码的主题框架与博客周杰伦歌词数据集测试循环神经网络中的架构一致,不同之处在于需要将mainpy文件中的以下语句替换为:

rnn_layer = get_rnn_layer(input_size=dict_size, hidden_size=hidden_size)
model = RNNModel(rnn_layer, dict_size).to(device)

↓↓↓

lstm_layer = nn.LSTM(input_size=dict_size, hidden_size=hidden_size)
model = RNNModel(lstm_layer, dict_size).to(device)

  输出如下:

epoch 50, perplexity 1.017165, time 1.56 sec
 - 分开始移动 回到当初爱你的时空 停格内容不忠 所有回忆对着我进攻       所有回忆对着我进攻    
 - 不分开 我知道这里很美但家乡的你更美走过了很多地方 我来到伊斯坦堡 就像是童话故事  有教堂有城堡 每天忙
epoch 100, perplexity 1.013933, time 1.57 sec
 - 分开始乡相信命运 感谢地心引力 让我碰到你 漂亮的让我面红的可爱女人 温柔的让我心疼的可爱女人 透明的让
 - 不分开 我对定会呵护著你 也逗你笑 你对我有多重要 我后悔没让你知道 安静的听你撒娇 看你睡著一直到老 就
epoch 150, perplexity 1.019680, time 1.55 sec
 - 分开始想像 爸和妈当年的模样 说著一口吴侬软语的姑娘缓缓走过外滩 消失的 旧时光 一九四三 在回忆 的路
 - 不分开 我知道这里很美但家乡的你更美原来我只想要你 陪我去吃汉堡  说穿了其实我的愿望就怎么小 就怎么每天
epoch 200, perplexity 2.351453, time 1.55 sec
 - 分开始想太你的 我都的可爱 不再考倒我 难过 我想躲 我不能再想 我不能再想 我不 我不 我不要再想 我
 - 不分开 那场外加油 你却还让我和狂的玩我 相思寄红豆 相思寄红豆走是人方的响尾蛇 无力的我爱你 爱情来的太
epoch 250, perplexity 1.014510, time 1.54 sec
 - 分开始打呼 管家是一只会说法语举止优雅的猪 吸血前会念约翰福音做为弥补 拥有一双蓝色眼睛的凯萨琳公主 专
 - 不分开 我用家二 在人海中 盲目跟从 别人的梦 全面放纵 恨没有用 疗伤止痛 不再感动 没有梦 痛不知轻重

致谢

感谢李沐、Aston Zhang等老师的这本《动手学深度学习》一书,为鄙人学习深度学习提供了很大的帮助。本文一系列关于深度学习的博客均无侵权之意,只为记录自己的深度学习历程。
  项目Github地址:https://github.com/ShusenTang/Dive-into-DL-PyTorch

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值