1、模型原理
LSTM(long short-term memory)是RNN的一种变体,RNN由于梯度消失的原因只能有短期记忆,LSTM网络通过精妙的门控制将短期记忆与长期记忆结合起来,并且一定程度上解决了梯度消失的问题。
所有 RNN 都具有一种重复神经网络模块的链式的形式。在标准的 RNN 中,这个重复的模块只有一个非常简单的结构,例如一个 tanh 层。
标准RNN网络的结构:
LSTM 同样是这样的结构,但是重复的模块拥有一个不同的结构。不同于 单一神经网络层,这里是有四个,以一种非常特殊的方式进行交互。
LSTM网络的结构:
2、代码实现
本文以简单数据集和网络结构实现TextLSTM,训练集中包括10个单词,用每个单词的前三个字符去预测最后一个字符,目的是便于读者更好的理解该网络的原理。
1. 导入需要的库,设置数据类型
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
dtype = torch.FloatTensor
2. 创建数据和字典
seq_data = ['make', 'need', 'coal', 'word', 'love', 'hate', 'live', 'home', 'hash', 'star']
char_arr = [c for c in 'abcdefghijklmnopqrstuvwxyz']
word_dict = {
w:n for n ,w in enumerate(char_arr)}
number_dict = {
n:w for n, w in enumerate(char_arr)}
n_class = len(char_arr) //number of class(=number of vocab)