1. 官方手册
2. output, h_n, c_n 之间的关系
首先,Pytorch中的LSTM有三个输出 output, hn, cn。
可以把hn理解为当前时刻,LSTM层的输出结果,而cn是记忆单元中的值,output则是包括当前时刻以及之前时刻所有hn的输出值
- 在只有单时间步的时候,
output = hn - 在多时间步时,
output可以看做是各个时间点hn的输出
3. 代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# 随机生成数据
torch.manual_seed(1)
# 定义一层的LSTM神经网络
lstm = nn.LSTM(3, 2) # Input dim is 3, output dim is 3
# 生成一个序列长度为5的数据,一个点看成是一个时刻
inputs = [torch.randn(1, 3) for _ in range(5)]
# 初始化 隐藏状态,
# LSTM层数为1
# batch为1
# hidden_size:1
# initialize the hidden state.
hidden = (torch.randn(1, 1, 2),
torch.randn(1, 1, 2))
# 第一种获取输出的方式,循环多个时间步,得到每个时刻的输出
for i in inputs:
# Step through the sequence one element at a time.
# after each step, hidden contains the hidden state.
out, hidden = lstm(i.view(1, 1, -1), hidden)
# 第二种获取输出的方式,把输入格式变为:seq_len, batch, input_size 的三维张量
inputs = torch.cat(inputs).view(len(inputs), 1, -1)
'''
tensor([[[ 0.3482, 1.1371, -0.3339]],
[[-1.4724, 0.7296, -0.1312]],
[[-0.6368, 1.0429, 0.4903]],
[[ 1.0318, -0.5989, 1.6015]],
[[-1.0735, -1.2173, 0.6472]]])
'''
hidden = (torch.randn(1, 1, 2), torch.randn(1, 1, 2)) # clean out hidden state
out, hidden = lstm(inputs, hidden)
print(out)
'''
tensor([[[-0.0468, 0.1818]],
[[-0.1173, 0.1622]],
[[-0.2076, 0.1286]],
[[-0.0474, 0.0851]],
[[-0.0185, 0.1172]]],
'''
print("h的最后一个值等于output的最后一个值")
print(hidden)
'''
h: tensor([[[-0.0185, 0.1172]]], grad_fn=<StackBackward>)
c: tensor([[[-0.0603, 0.1560]]], grad_fn=<StackBackward>)
'''