Pytorch LSTM函数参数解释 图解
笔者最近在写有关LSTM的代码,但是对于nn.LSTM函数中的有些参数还是不明白其具体含义,学习过后在此记录。
为了方便说明,我们先解释函数参数的作用,接着对应图片来说明每个参数的具体含义。
torch.nn.LSTM函数
LSTM的函数
class torch.nn.LSTM(args, *kwargs)
# 主要参数
# input_size – 输入的特征维度
# hidden_size – 隐状态的特征维度
# num_layers – 层数(和时序展开要区分开)
# bias – 如果为False,那么LSTM将不会使用偏置,默认为True。
# batch_first – 如果为True,那么输入和输出Tensor的形状为(batch, seq_len, input_size)
# dropout – 如果非零的话,将会在RNN的输出上加个dropout,最后一层除外。
# bidirectional – 如果为True,将会变成一个双向RNN,默认为False。
LSTM的输入维度为 (seq_len, batch, input_size) 如果batch_first为True,则输入形状为(batch, seq_len, input_size)
seq_len是文本的长度;
batch是批次的大小;
input_size是每个输入的特征纬度(一般是每个字/单词的向量表示;
LSTM的输出维度为 (seq_len, batch, hidden_size * num_directions)
seq_len是文本的长度;
batch是批次的大小;
hidden_size是定义的隐藏层长度
num_directions指的则是如果是普通LSTM该值为1; Bi-LSTM该值为2
当然,仅仅用文本来说明则让人感到很懵逼,所以我们使用图片来说明。
图解LSTM函数
我们常见的LSTM的图示是这样的:
但是这张图很具有迷惑性,让我们不易理解LSTM各个参数的意义。具体将上图中每个单元展开则为下图所示:
input_size: 图1中
x
i
x_i
xi与图2中绿色节点对应,而绿色节点的长度等于input_size(一般是每个字/单词的向量表示)。
hidden_size: 图2中黄色节点的数量
num_layers: 图2中黄色节点的层数(该图为1)
引用图片
LSTM参数的问题: 链接.