LSTM 简介
LSTM是RNN中一个较为流行的网络模块。主要包括输入,输入门,输出门,遗忘门,激活函数,全连接层(Cell)和输出。其结构如下:
上述公式不做解释,我们只要大概记得以下几个点就可以了:
- 当前时刻LSTM模块的输入有来自当前时刻的输入值,上一时刻的输出值,输入值和隐含层输出值,就是一共有四个输入值,这意味着一个LSTM模块的输入量是原来普通全连接层的四倍左右,计算量多了许多。
- 所谓的门就是前一时刻的计算值输入到sigmoid激活函数得到一个概率值,这个概率值决定了当前输入的强弱程度。 这个概率值和当前输入进行矩阵乘法得到经过门控处理后的实际值。
- 门控的激活函数都是sigmoid,范围在(0,1),而输出输出单元的激活函数都是tanh,范围在(-1,1)。
Pytorch实现如下:
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.nn import init
from torch import Tensor
import math
class NaiveLSTM(nn.Module):
"""Naive LSTM like nn.LSTM"""
def __init__(self, input_size: int, hidden_size: int):
super(NaiveLSTM, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
# input gate
self.w_ii = Parameter(Tensor(hidden_size, input_size))
self.w_hi = Parameter(Tensor(hidden_size, hidden_size))
self.b_ii = Parameter(Tensor(hidden_size, 1))
self.b_hi = Parameter(Tensor(hidden_size, 1))
# forget gate
self.w_if = Parameter(Tensor(hidden_size, input_size)