LSTM网络算法的原理与代码实现

目录

1.LSTM(Long Short Term Memory)网络流程图

 2.长短期记忆网络(LSTM)

 2.1输入门、忘记门和输出门

 2.2候选记忆元

2.3记忆元

2.4隐状态

 3.LSTM网络的代码实现

 4. pytorch简洁实现版本-LSTM网络模型


 

1.LSTM(Long Short Term Memory)网络流程图

 2.长短期记忆网络(LSTM)

LSTM的设计灵感来源于计算机的逻辑门

  • 输入门、忘记门和输出门

 2.1输入门、忘记门和输出门

 

 2.2候选记忆元

 

2.3记忆元

2.4隐状态

 3.LSTM网络的代码实现

#导包
import torch
from torch import nn
import dltools


#加载数据
#声明变量
batch_size, num_steps = 32, 35
#获取训练数据的迭代器,词汇表
train_iter, vocab = dltools.load_data_time_machine(batch_size, num_steps)


#封装函数:实现初始化模型参数
def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size
    
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01
    
    def three():
        return (normal((num_inputs, num_hiddens)),
               normal((num_hiddens, num_hiddens)),
               torch.zeros(num_hiddens, device=device))
    
    W_xi, W_hi, b_i = three() # 输入门参数
    W_xf, W_hf, b_f = three() # 遗忘门参数
    W_xo, W_ho, b_o = three() # 输出门参数
    W_xc, W_hc, b_c = three() # 候选记忆元参数
    
    # 输出层
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q]
    
    for param in params:
        param.requires_grad_(True)
        
    return params


#初始化隐藏状态state和记忆元C_tilda¶
#初始化因那个状态和记忆元
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))


#定义LSTM的主体结构
#定义lstm主体结构
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params
    (H, C) = state  #获取隐藏state与记忆元
    
    outputs = []
    
    #进行前向传播计算
    for X in inputs:   #循环输入数据
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q   #输出预测
        outputs.append(Y)
    #torch.cat(outputs, dim=0)按照0维度合并
    return torch.cat(outputs, dim=0), (H, C)


##测试代码的可行性
X = torch.arange(10).reshape((2, 5))
num_hiddens = 512
net = dltools.RNNModelScratch(len(vocab), num_hiddens, dltools.try_gpu(), get_lstm_params, init_lstm_state, lstm)
state = net.begin_state(X.shape[0], dltools.try_gpu())
Y, new_state = net(X.to(dltools.try_gpu()), state)


## 训练和预测
vocab_size, num_hiddens, device =  len(vocab), 256, dltools.try_gpu()
num_epochs, lr = 500, 3
model = dltools.RNNModelScratch(vocab_size, num_hiddens, device, get_lstm_params, init_lstm_state, lstm)
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

 

 4. pytorch简洁实现版本-LSTM网络模型

# pytorch简洁实现版本.
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = dltools.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

 

5.知识点个人理解

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值