Pytorch LSTM函数参数解释 图解

Pytorch LSTM函数参数解释 图解

笔者最近在写有关LSTM的代码,但是对于nn.LSTM函数中的有些参数还是不明白其具体含义,学习过后在此记录。

为了方便说明,我们先解释函数参数的作用,接着对应图片来说明每个参数的具体含义。

torch.nn.LSTM函数

LSTM的函数

class torch.nn.LSTM(args, *kwargs)
	# 主要参数
	# input_size – 输入的特征维度
	# hidden_size – 隐状态的特征维度
	# num_layers – 层数(和时序展开要区分开)
	# bias – 如果为False,那么LSTM将不会使用偏置,默认为True。
	# batch_first – 如果为True,那么输入和输出Tensor的形状为(batch, seq_len, input_size)
	# dropout – 如果非零的话,将会在RNN的输出上加个dropout,最后一层除外。
	# bidirectional – 如果为True,将会变成一个双向RNN,默认为False。

LSTM的输入维度(seq_len, batch, input_size) 如果batch_first为True,则输入形状为(batch, seq_len, input_size)
  seq_len是文本的长度;
  batch是批次的大小;
  input_size是每个输入的特征纬度(一般是每个字/单词的向量表示;

LSTM的输出维度(seq_len, batch, hidden_size * num_directions)
  seq_len是文本的长度;
  batch是批次的大小;
  hidden_size是定义的隐藏层长度
  num_directions指的则是如果是普通LSTM该值为1; Bi-LSTM该值为2

当然,仅仅用文本来说明则让人感到很懵逼,所以我们使用图片来说明。

图解LSTM函数

我们常见的LSTM的图示是这样的:
LSTM常见说明
但是这张图很具有迷惑性,让我们不易理解LSTM各个参数的意义。具体将上图中每个单元展开则为下图所示:
在这里插入图片描述
input_size: 图1中 x i x_i xi与图2中绿色节点对应,而绿色节点的长度等于input_size(一般是每个字/单词的向量表示)。
hidden_size: 图2中黄色节点的数量
num_layers: 图2中黄色节点的层数(该图为1)

引用图片

LSTM参数的问题: 链接.

### LSTM网络结构详解 #### 背景介绍 循环神经网络(Recurrent Neural Network, RNN)被广泛用于处理序列数据,但由于其固有的缺陷,在训练过程中容易遭遇梯度消失或梯度爆炸问题[^3]。这些问题限制了RNN捕捉长时间依赖关系的能力。 #### LSTM的核心概念 为了解决上述问题,Hochreiter和Schmidhuber于1997年提出了长短期记忆网络(Long Short-Term Memory, LSTM)。LSTM通过引入一种特殊的单元结构——记忆细胞(memory cell),以及三个门控机制(输入门、遗忘门和输出门),实现了对信息的选择性存储与删除功能。 #### 结构组成 LSTM的主要组成部分包括以下几个方面: 1. **记忆细胞(Cell State)** 记忆细胞是LSTM的关键组件之一,它贯穿整个时间步并控制着信息流的方向。这种设计允许重要的信息得以长期保存,而无关的信息则会被逐步移除。 2. **输入门(Input Gate)** 输入门决定了当前时刻的输入有多少应该写入到记忆细胞中去。这一过程通常涉及激活函数sigmoid和tanh的操作: \[ i_t = \sigma(W_{xi}x_t + W_{hi}h_{t-1} + b_i) \] \[ g_t = tanh(W_xg x_t + W_hg h_{t-1} + b_g) \] 这里 \(i_t\) 表示输入门的状态向量,\(g_t\) 是候选值向量。 3. **遗忘门(Forget Gate)** 遗忘门的作用在于决定上一时刻的记忆细胞中有多少比例的内容需要被保留下来或者丢弃掉。具体表达式如下所示: \[ f_t = \sigma(W_{xf}x_t + W_{hf}h_{t-1} + b_f) \] 其中,\(\sigma\) 函数返回的是介于0至1之间的数值,代表遗忘的程度。 4. **输出门(Output Gate)** 输出门负责筛选出最终要作为本节点输出的部分,并将其传递给下一个隐藏层节点或者是预测目标。计算方式如下: \[ o_t = \sigma(W_{xo}x_t + W_{ho}h_{t-1} + b_o) \] \[ c_t = f_t * c_{t-1} + i_t * g_t \] \[ h_t = o_t * tanh(c_t) \] 以上便是标准版LSTM的基本运算逻辑框架图解说明。 ```python import torch.nn as nn class LSTMModel(nn.Module): def __init__(self, input_size, hidden_size, num_layers, output_size): super(LSTMModel, self).__init__() self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True) self.fc = nn.Linear(hidden_size, output_size) def forward(self, x): out, _ = self.lstm(x) out = self.fc(out[:, -1, :]) return out ``` 此代码片段展示了一个简单的基于PyTorch实现的LSTM模型定义方法[^2]。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值