1. 长短时记忆神经网络
1.1 长短时记忆神经网络
- 长短时记忆神经网络(Long Short Term Memory, LSTM )是一种RNN特殊的类型,可以学习长期依赖信息。在很多问题上,LSTM都取得巨大的成功,并得到了广泛的应用。
- LSTM能够有效捕捉长序列之间的语义关联, 缓解梯度消失或爆炸现象. 同时LSTM的结构更复杂, 它的核心结构可以分为四个部分去解析,具体包括遗忘门、输入门、细胞状态、输出门
1.2 LSTM的网络结构
-
遗忘门:LSTM的遗忘门通过sigmiod函数决定哪些信心会被遗忘,经过sigmoid函数,会输出0~1之间的一个值,这个值会和前一次的细胞状态进行点乘,从而决定遗忘或者保留
-
输入门:LSTM的输入门决定哪些新的信息会被保留,这个过程有两步:
- 输入信息经过sigmoid层决定哪些信息会被更新
- tanh会创出一个新的候选向量,后续会被添加到细胞状态中
-
细胞状态更新:
- 旧的细胞状态和遗忘门结果相乘
- 然后加上输入门和tanh相乘的结果
-
输出门:LSTM的输出决定哪些信息会被输出,同样这个输出经过变换之后会通过sigmoid函数的结果来决定那些细胞状态会被输出。
1.3 步骤
- 步骤一:导入工具库
import torch
import torch.nn as nn
- 步骤二:LSTM网络搭建
class LstmModel(nn.Module):
def __init__(self):
super(LstmModel,self).__init__()
self.rnn = nn.LSTM(
input_size=1,
hidden_size=32,
num_layers=1
)
self.out = nn.Linear(32,1)
def forward(self,x,h):
# x (time_step, batch_size,input_size)
out,h = self.rnn(x,h)
prediction = self.out(out)
return prediction,h
- 步骤三:输出模型结构
rnn = LstmModel()
print(rnn)
1.4 运行结果
运行结果:
D:/Users/tarena/PycharmProjects/nlp/unit30/lstm_model.py
LstmModel(
(rnn): LSTM(1, 32)
(out): Linear(in_features=32, out_features=1, bias=True)
)
Process finished with exit code 0