PyTorch 中 LSTM 的 output、h_n 和 c_n 之间的关系

LSTM 简介

  • 官方文档:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
  • h_n:最后一个时间步的输出,即 h_n = output[:, -1, :](一般可以直接输入到后续的全连接层,在 Keras 中通过设置参数 return_sequences=False 获得)
  • c_n:最后一个时间步 LSTM cell 的状态(一般用不到)

实例

  • 实例:根据红框可以直观看出,h_n 是最后一个时间步的输出,即是 h_n = output[:, -1, :],如何还是无法直观理解,直接看如下截图,对照代码可以非常容易看出它们的关系

  • 实例代码:

>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.LSTM(input_size=2, hidden_size=3, batch_first=True)
>>> input = torch.randn(5,4,2)
>>> h0 = torch.randn(1, 5, 3)
>>> c0 = torch.randn(1, 5, 3)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> output
tensor([[[-0.1046, -0.0316, -0.2261],
         [ 0.0702,  0.0756, -0.2856],
         [ 0.1146,  0.0666, -0.1841],
         [ 0.1137,  0.0508, -0.3966]],

        [[ 0.3702, -0.1192, -0.3513],
         [ 0.3964, -0.0513, -0.1744],
         [ 0.3144,  0.0564, -0.2114],
         [ 0.3056,  0.1312, -0.1656]],

        [[ 0.1581, -0.3509,  0.0068],
         [ 0.2391, -0.0308,  0.0773],
         [ 0.2420,  0.0607, -0.0652],
         [ 0.2854,  0.0656, -0.0306]],

        [[-0.0562, -0.0229,  0.1600],
         [-0.2156, -0.0006,  0.0898],
         [ 0.0700,  0.2200, -0.0068],
         [ 0.1903,  0.3120,  0.0253]],

        [[ 0.1025, -0.0167,  0.3068],
         [ 0.2028,  0.0652,  0.1738],
         [ 0.3324,  0.1645,  0.1908],
         [ 0.2594,  0.0896, -0.0507]]], grad_fn=<TransposeBackward0>)
>>> hn
tensor([[[ 0.1137,  0.0508, -0.3966],
         [ 0.3056,  0.1312, -0.1656],
         [ 0.2854,  0.0656, -0.0306],
         [ 0.1903,  0.3120,  0.0253],
         [ 0.2594,  0.0896, -0.0507]]], grad_fn=<StackBackward>)
>>> cn
tensor([[[ 0.3811,  0.2079, -0.7427],
         [ 0.9059,  0.2375, -0.3272],
         [ 0.5819,  0.1175, -0.0766],
         [ 0.5059,  0.5022,  0.0446],
         [ 0.7312,  0.2270, -0.0970]]], grad_fn=<StackBackward>)
>>> output[-1]
tensor([[ 0.1025, -0.0167,  0.3068],
        [ 0.2028,  0.0652,  0.1738],
        [ 0.3324,  0.1645,  0.1908],
        [ 0.2594,  0.0896, -0.0507]], grad_fn=<SelectBackward>)
>>> output[:,:,-1]
tensor([[-0.2261, -0.2856, -0.1841, -0.3966],
        [-0.3513, -0.1744, -0.2114, -0.1656],
        [ 0.0068,  0.0773, -0.0652, -0.0306],
        [ 0.1600,  0.0898, -0.0068,  0.0253],
        [ 0.3068,  0.1738,  0.1908, -0.0507]], grad_fn=<SelectBackward>)
>>> output[:,-1,:]
tensor([[ 0.1137,  0.0508, -0.3966],
        [ 0.3056,  0.1312, -0.1656],
        [ 0.2854,  0.0656, -0.0306],
        [ 0.1903,  0.3120,  0.0253],
        [ 0.2594,  0.0896, -0.0507]], grad_fn=<SliceBackward>)
>>> output[:,-1,:].shape
torch.Size([5, 3])
>>> output.shape
torch.Size([5, 4, 3])
>>> hn.shape
torch.Size([1, 5, 3])
>>> cn.shape
torch.Size([1, 5, 3])
  • 8
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论
PyTorch ,可以通过 `nn.LSTM` 类来创建 LSTM 模型。 下面是一个简单的 LSTM 模型的示例代码: ```python import torch.nn as nn class LSTMModel(nn.Module): def __init__(self, input_size, hidden_size, num_layers, output_size): super(LSTMModel, self).__init__() self.hidden_size = hidden_size self.num_layers = num_layers self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) self.fc = nn.Linear(hidden_size, output_size) def forward(self, x): h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) out, _ = self.lstm(x, (h0, c0)) out = self.fc(out[:, -1, :]) return out ``` 在上面的代码,`nn.LSTM` 的第一个参数是输入特征的维度,第二个参数是隐藏层的维度,第三个参数是 LSTM 层数,`batch_first=True` 表示输入的第一个维度是 batch 的大小。 在 `forward` 方法,我们首先定义了初始的隐藏状态 `h0` 和记忆状态 `c0`,然后将输入 `x` 和初始状态传入 `nn.LSTM` ,得到输出 `out`。由于我们只关心最后一个时间步的输出,所以我们取 `out[:, -1, :]`,然后将其传入全连接层 `nn.Linear` 得到最终的输出。 需要注意的是,在使用 LSTM 模型时,需要将输入数据转换为 PyTorch 的张量,并将其发送到 GPU 上进行计算。例如,`x = torch.tensor(x).float().to(device)`。其 `x` 是输入数据,`device` 是计算设备。
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

csdn-WJW

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值