torch.nn.LSTM makes it easy to build an LSTM. If you are not yet familiar with LSTMs, read these two articles first:
RNN: https://blog.csdn.net/yizhishuixiong/article/details/105588233
LSTM: https://blog.csdn.net/yizhishuixiong/article/details/105572296
The usage of torch.nn.LSTM is described in detail below.
torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False)
- input_size: dimension of each input feature vector;
- hidden_size: size of the hidden layer (number of units); the dimension of each output vector equals this;
- num_layers: number of stacked recurrent layers (default 1);
- bias: whether to use bias terms; default True;
- batch_first: whether the first dimension of input and output is batch_size. If True, batch_size comes first; if False (default), batch_size is the second dimension;
- dropout: if non-zero, applies a dropout layer on the outputs of every layer except the last; default 0;
- bidirectional: if True, uses a bidirectional LSTM; default False;
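To make the batch_first parameter concrete, here is a minimal sketch (variable names are illustrative). Note one easy-to-miss detail: batch_first only changes the layout of input and output, not of the hidden/cell states, which always stay (num_layers * num_directions, batch, hidden_size).

```python
import torch
import torch.nn as nn

# With batch_first=True, input and output use (batch, seq_len, features)
rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=1, batch_first=True)
x = torch.randn(3, 8, 10)  # (batch=3, seq_len=8, input_size=10)
output, (h_n, c_n) = rnn(x)
print(output.shape)  # torch.Size([3, 8, 20]) -- batch first in output too
print(h_n.shape)     # torch.Size([1, 3, 20]) -- states keep (layers, batch, hidden)
```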
LSTM inputs: input, (h_0, c_0)
- input: input data, shape (seq_len, batch, input_size), i.e. (sentence length, number of sentences, word-vector dimension);
- h_0: defaults to zeros, shape (num_layers * num_directions, batch, hidden_size), where num_directions is 1 for unidirectional and 2 for bidirectional;
- c_0: defaults to zeros, shape (num_layers * num_directions, batch, hidden_size);
LSTM outputs: output, (h_n, c_n)
- output: shape (seq_len, batch, num_directions * hidden_size);
- h_n: shape (num_layers * num_directions, batch, hidden_size);
- c_n: shape (num_layers * num_directions, batch, hidden_size);
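Since h_0 and c_0 default to zeros when omitted, passing explicit zero tensors gives exactly the same result. A quick sketch verifying this (shapes chosen arbitrarily):

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)        # input_size=10, hidden_size=20, num_layers=2
x = torch.randn(5, 3, 10)       # (seq_len=5, batch=3, input_size=10)

# Omitting (h_0, c_0) uses all-zero initial states
out1, _ = rnn(x)

# Passing explicit zeros of shape (num_layers * num_directions, batch, hidden_size)
h0 = torch.zeros(2, 3, 20)
c0 = torch.zeros(2, 3, 20)
out2, _ = rnn(x, (h0, c0))

print(torch.allclose(out1, out2))  # True
```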
Code demo
import torch
import torch.nn as nn
rnn = nn.LSTM(10, 20, 3) # input_size=10 (word-vector length), hidden_size=20, num_layers=3 (three stacked LSTM layers)
input = torch.randn(8, 3, 10) # seq_len=8 (8 words per sentence), batch_size=3 (3 sentences), input_size=10
h_0, c_0 = torch.randn(3, 3, 20), torch.randn(3, 3, 20) # (num_layers, batch, hidden_size)
output, (h_n, c_n) = rnn(input, (h_0, c_0))
print("input.shape:", input.shape)
print("h_0.shape:", h_0.shape)
print("c_0.shape:", c_0.shape)
print("*" * 50)
print("output.shape:", output.shape)
print("h_n.shape:", h_n.shape)
print("c_n.shape:", c_n.shape)
Bidirectional:
import torch
import torch.nn as nn
rnn = nn.LSTM(10, 20, 3, bidirectional=True) # input_size=10, hidden_size=20, num_layers=3, bidirectional
input = torch.randn(8, 3, 10) # seq_len=8, batch_size=3 (3 sentences), input_size=10
h_0, c_0 = torch.randn(6, 3, 20), torch.randn(6, 3, 20) # (num_layers * 2 directions = 6, batch, hidden_size)
output, (h_n, c_n) = rnn(input, (h_0, c_0))
print("input.shape:", input.shape)
print("h_0.shape:", h_0.shape)
print("c_0.shape:", c_0.shape)
print("*" * 50)
print("output.shape:", output.shape)
print("h_n.shape:", h_n.shape)
print("c_n.shape:", c_n.shape)