代码文件
import torch
def lstm(X, state, params):
W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q = params
H, C = state
# 遗忘门
F = torch.sigmoid(torch.mm(X, W_xf) + torch.mm(H, W_hf) + b_f)
# 输入门
I = torch.sigmoid(torch.mm(X, W_xi) + torch.mm(H, W_hi) + b_i)
C_tilde = torch.tanh(torch.mm(X, W_xc) + torch.mm(H, W_hc) + b_c)
# 更新单元状态
C = F * C + I * C_tilde
# 输出门
O = torch.sigmoid(torch.mm(X, W_xo) + torch.mm(H, W_ho) + b_o)
H = O * torch.tanh(C)
# 输出层
Y = torch.mm(H, W_hq) + b_q
return Y, (H, C)
题目描述
任务描述
本关任务:通过学习长短时记忆网络相关知识,编写实现长短时记忆网络。
相关知识
为了完成本关任务,你需要掌握:
- 长短时记忆网络;
- 门结构;
- 长短时记忆网络实现。
长短时记忆网络
传统的循环神经网络受限于梯度爆炸与梯度消失问题,使得网络随着输入序列的增长,抖动变得更为剧烈,导致无法学习 。长短时记忆网络( Long Short Term Memory Network, LSTM )便是为了解决此问题而被设计提出。其核心思想是通过添加一个网络内部状态c
来记忆长期信息,这个新的状态我们称之为单元状态(Cell State),主要负责记忆长期信息。