6.3 LSTM的记忆能力实验
长短期记忆网络(Long Short-Term Memory Network,LSTM)是一种可以有效缓解长程依赖问题的循环神经网络.LSTM 的特点是引入了一个新的内部状态(Internal State)c∈RD 和门控机制(Gating Mechanism).不同时刻的内部状态以近似线性的方式进行传递,从而缓解梯度消失或梯度爆炸问题.同时门控机制进行信息筛选,可以有效地增加记忆能力.例如,输入门可以让网络忽略无关紧要的输入信息,遗忘门可以使得网络保留有用的历史信息.在上一节的数字求和任务中,如果模型能够记住前两个非零数字,同时忽略掉一些不重要的干扰信息,那么即时序列很长,模型也有效地进行预测.
使用LSTM模型重新进行数字求和实验,验证LSTM模型的长程依赖能力。
LSTM 模型在第 t 步时,循环单元的内部结构如图所示.
6.3.1 模型构建
在本实验中,我们将使用第6.1.2.4节中定义Model_RNN4SeqClass模型,并构建 LSTM 算子.只需要实例化 LSTM 算,并传入Model_RNN4SeqClass模型,就可以用 LSTM 进行数字求和实验.
6.3.1.1 LSTM层
LSTM层的代码与SRN层结构相似,只是在SRN层的基础上增加了内部状态、输入门、遗忘门和输出门的定义和计算。这里LSTM层的输出也依然为序列的最后一个位置的隐状态向量。代码实现如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
# 声明LSTM和相关参数
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, Wi_attr=None, Wf_attr=None, Wo_attr=None, Wc_attr=None,
Ui_attr=None, Uf_attr=None, Uo_attr=None, Uc_attr=None, bi_attr=None, bf_attr=None,
bo_attr=None, bc_attr=None):
super(LSTM, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
# 初始化模型参数
if Wi_attr==None:
Wi = torch.zeros(size=[input_size, hidden_size], dtype=torch.float32)
else:
Wi = torch.tensor(Wi_attr, dtype=torch.float32)
self.W_i = torch.nn.Parameter(Wi)
if Wf_attr==None:
Wf = torch.zeros(size=[input_size, hidden_size], dtype=torch.float32)
else:
Wf = torch.tensor(Wf_attr, dtype=torch.float32)
self.W_f = torch.nn.Parameter(Wf)
if Wo_attr==None:
Wo = torch.zeros(size=[input_size, hidden_size], dtype=torch.float32)
else:
Wo = torch.tensor(Wo_attr, dtype=torch.float32)
self.W_o = torch.nn.Parameter(Wo)
if Wc_attr==None:
Wc = torch.zeros(size=[input_size, hidden_size], dtype=torch.float32)
else:
Wc = torch.tensor(Wc_attr, dtype=torch.float32)
self.W_c = torch.nn.Parameter(Wc)
if Ui_attr==None:
Ui = torch.zeros(size=[hidden_size, hidden_size], dtype=torch.float32)
else:
Ui = torch.tensor(Ui_attr, dtype=torch.float32)
self.U_i = torch.nn.Parameter(Ui)
if Uf_attr == None:
Uf = torch.zeros(size=[hidden_size, hidden_size], dtype=torch.float32)
else:
Uf = torch.tensor(Uf_attr, dtype=torch.float32)
self.U_f = torch.nn.Parameter(Uf)
if Uo_attr == None:
Uo = torch.zeros(size=[hidden_size, hidden_size], dtype=torch.float32)
else:
Uo = torch.tensor(Uo_attr, dtype=torch.float32)
self.U_o = torch.nn.Parameter(Uo)
if Uc_attr == None:
Uc = torch.zeros(size=[hidden_size, hidden_size], dtype=torch.float32)
else:
Uc = torch.tensor(Uc_attr, dtype=torch.float32)
self.U_c = torch.nn.Parameter(Uc)
if bi_attr == None:
bi = torch.zeros(size=[1,hidden_size], dtype=torch.float32)
else:
bi = torch.tensor(bi_attr, dtype=torch.float32)
self.b_i = torch.nn.Parameter(bi)
if bf_attr == None:
bf = torch.zeros(size=[1,hidden_size], dtype=torch.float32)
else:
bf = torch.tensor(bf_attr, dtype=torch.float32)
self.b_f = torch.nn.Parameter(bf)
if bo_attr == None:
bo = torch.zeros(size=[1,hidden_size], dtype=torch.float32)
else:
bo = torch.tensor(bo_attr, dtype=torch.float32)
self.b_o = torch.nn.Parameter(bo)
if bc_attr == None:
bc = torch.zeros(size=[1,hidden_size], dtype=torch.float32)
else:
bc = torch.tensor(bc_attr, dtype=torch.float32)
self.b_c = torch.nn.Parameter(bc)
# 初始化状态向量和隐状态向量
def init_state(self, batch_size):
hidden_state = torch.zeros(size=[batch_size, self.hidden_size], dtype=torch.float32)
cell_state = torch.zeros(size=[batch_size, self.hidden_size], dtype=torch.float32)
return hidden_state, cell_state
# 定义前向计算
def forward(self, inputs, states=None):
# inputs: 输入数据,其shape为batch_size x seq_len x input_size
batch_size, seq_len, input_size = inputs.shape
# 初始化起始的单元状态和隐状态向量,其shape为batch_size x hidden_size
if states is None:
states = self.init_state(batch_size)
hidden_state, cell_state = states
# 执行LSTM计算,包括:输入门、遗忘门和输出门、候选内部状态、内部状态和隐状态向量
for step in range(seq_len):
# 获取当前时刻的输入数据step_input: 其shape为batch_size x input_size
step_input = inputs[:, step, :]
# 计算输入门, 遗忘门和输出门, 其shape为:batch_size x hidden_size
I_gate = F.sigmoid(torch.matmul(step_input, self.W_i) + torch.matmul(hidden_state, self.U_i) + self.b_i)
F_gate = F.sigmoid(torch.matmul(step_input, self.W_f) + torch.matmul(hidden_state, self.U_f) + self.b_f)
O_gate = F