目录
1.LSTM(Long Short Term Memory)网络流程图
1.LSTM(Long Short Term Memory)网络流程图
2.长短期记忆网络(LSTM)
LSTM的设计灵感来源于计算机的逻辑门
- 输入门、忘记门和输出门
2.1输入门、忘记门和输出门
2.2候选记忆元
2.3记忆元
2.4隐状态
3.LSTM网络的代码实现
#导包
import torch
from torch import nn
import dltools
# ---- Data loading ----
# Mini-batch size and number of unrolled time steps per training example.
batch_size, num_steps = 32, 35
# Build the training iterator over the time-machine corpus and its vocabulary.
train_iter, vocab = dltools.load_data_time_machine(batch_size, num_steps)
# Build and return the learnable parameters of the scratch LSTM.
def get_lstm_params(vocab_size, num_hiddens, device):
    """Create the 14 learnable tensors of a from-scratch LSTM.

    Returns a list
        [W_xi, W_hi, b_i,   # input gate
         W_xf, W_hf, b_f,   # forget gate
         W_xo, W_ho, b_o,   # output gate
         W_xc, W_hc, b_c,   # candidate memory cell
         W_hq, b_q]         # output projection
    where every tensor lives on *device* and has requires_grad=True.
    """
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        # Small random init keeps the sigmoid/tanh activations near their
        # linear region at the start of training.
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        # One gate = input-to-hidden weight, hidden-to-hidden weight, bias.
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    params = []
    # Same creation order as before: input, forget, output gates, then the
    # candidate cell — this also preserves the RNG consumption order.
    for _ in range(4):
        params.extend(three())
    # Projection from the hidden state to vocabulary logits.
    params.append(normal((num_hiddens, num_outputs)))
    params.append(torch.zeros(num_outputs, device=device))
    # Mark every tensor as trainable.
    for p in params:
        p.requires_grad_(True)
    return params
# Initialize the hidden state H and the memory cell C for a fresh sequence.
def init_lstm_state(batch_size, num_hiddens, device):
    """Return the zero-filled initial (H, C) pair, each (batch_size, num_hiddens)."""
    shape = (batch_size, num_hiddens)
    return (torch.zeros(shape, device=device),
            torch.zeros(shape, device=device))
# Forward pass of the from-scratch LSTM.
def lstm(inputs, state, params):
    """Run the LSTM over a whole sequence.

    inputs: iterable of time steps; each step X has shape (batch, vocab).
    state:  (H, C) — hidden state and memory cell, each (batch, num_hiddens).
    params: the 14 tensors produced by get_lstm_params, in the same order.
    Returns (outputs, (H, C)) where outputs concatenates each step's logits
    along dim 0, giving shape (num_steps * batch, vocab).
    """
    (W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
     W_xc, W_hc, b_c, W_hq, b_q) = params
    H, C = state
    outputs = []
    for X in inputs:
        # Gate activations: input (I), forget (F), and output (O) gates.
        I = torch.sigmoid(X @ W_xi + H @ W_hi + b_i)
        F = torch.sigmoid(X @ W_xf + H @ W_hf + b_f)
        O = torch.sigmoid(X @ W_xo + H @ W_ho + b_o)
        # Candidate memory cell.
        C_tilda = torch.tanh(X @ W_xc + H @ W_hc + b_c)
        # Keep part of the old memory (F) and admit part of the candidate (I).
        C = F * C + I * C_tilda
        # The hidden state is a gated, squashed view of the memory cell.
        H = O * torch.tanh(C)
        # Per-step logits over the vocabulary.
        outputs.append(H @ W_hq + b_q)
    # Stack every step's logits along dim 0 and return the final state.
    return torch.cat(outputs, dim=0), (H, C)
## Smoke test: verify that one forward pass runs end to end.
X = torch.arange(10).reshape((2, 5))  # toy batch: 2 sequences of 5 token ids
num_hiddens = 512
net = dltools.RNNModelScratch(len(vocab), num_hiddens, dltools.try_gpu(),
                              get_lstm_params, init_lstm_state, lstm)
state = net.begin_state(X.shape[0], dltools.try_gpu())
Y, new_state = net(X.to(dltools.try_gpu()), state)
## Train the scratch LSTM and run prediction.
vocab_size, num_hiddens, device = len(vocab), 256, dltools.try_gpu()
num_epochs, lr = 500, 3
model = dltools.RNNModelScratch(vocab_size, num_hiddens, device,
                                get_lstm_params, init_lstm_state, lstm)
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
4. pytorch简洁实现版本-LSTM网络模型
# Concise version: reuse PyTorch's built-in nn.LSTM instead of the
# handwritten cell, wrapped by dltools.RNNModel, and train it the same way.
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = dltools.RNNModel(lstm_layer, len(vocab)).to(device)
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
5.知识点个人理解