PyTorch nn.LSTM: Usage Notes
PyTorch already ships an LSTM implementation, so nn.LSTM can be used directly; see the official documentation for the full parameter list. This post focuses on the input and output tensor shapes of the nn.LSTM interface.
The test code is as follows:
import torch
import torch.nn as nn

batch_size = 3
sentence_length = 2
input_size = 5
hidden_size = 12
num_layers = 2
bidirectional = False
batch_first = True

if batch_first:
    input_tensor = torch.randn(batch_size, sentence_length, input_size)
else:
    input_tensor = torch.randn(sentence_length, batch_size, input_size)

if bidirectional:
    # Halve the per-direction size so the concatenated output stays at hidden_size
    hidden_dim = hidden_size // 2
    h0 = torch.randn(num_layers * 2, batch_size, hidden_dim)
    c0 = torch.randn(num_layers * 2, batch_size, hidden_dim)
else:
    hidden_dim = hidden_size
    h0 = torch.randn(num_layers, batch_size, hidden_dim)
    c0 = torch.randn(num_layers, batch_size, hidden_dim)

# Note: nn.LSTM defaults to batch_first=False
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, num_layers=num_layers,
               batch_first=batch_first, bidirectional=bidirectional)
output, (hn, cn) = lstm(input_tensor, (h0, c0))

print('output shape:', output.shape)
print('hn shape:', hn.shape)
print('cn shape:', cn.shape)
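A side note before the results: the initial states (h0, c0) are optional. Per the PyTorch docs, omitting them is equivalent to passing all-zero initial states. A quick standalone check (the sizes mirror the script above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=5, hidden_size=12, num_layers=2, batch_first=True)
x = torch.randn(3, 2, 5)  # (batch, seq, input)

# Calling without (h0, c0) uses zero initial states under the hood
out_default, _ = lstm(x)
zeros = torch.zeros(2, 3, 12)  # (num_layers, batch, hidden_size)
out_zero, _ = lstm(x, (zeros, zeros))
print(torch.allclose(out_default, out_zero))  # True
```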
1. Unidirectional LSTM, batch_size in the first dimension

bidirectional = False
batch_first = True

# output shapes
output shape: torch.Size([3, 2, 12])
hn shape: torch.Size([2, 3, 12])
cn shape: torch.Size([2, 3, 12])
2. Unidirectional LSTM, batch_size in the second dimension

bidirectional = False
batch_first = False

# output shapes
output shape: torch.Size([2, 3, 12])
hn shape: torch.Size([2, 3, 12])
cn shape: torch.Size([2, 3, 12])
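Cases 1 and 2 differ only in memory layout, not in the computation. A sketch demonstrating this: give two LSTMs identical weights, feed one a (batch, seq, input) tensor and the other its transpose, and the outputs match after transposing back.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm_bf = nn.LSTM(input_size=5, hidden_size=12, num_layers=2, batch_first=True)
lstm_sf = nn.LSTM(input_size=5, hidden_size=12, num_layers=2, batch_first=False)
lstm_sf.load_state_dict(lstm_bf.state_dict())  # identical weights

x = torch.randn(3, 2, 5)                 # (batch, seq, input)
out_bf, _ = lstm_bf(x)
out_sf, _ = lstm_sf(x.transpose(0, 1))   # (seq, batch, input)
print(torch.allclose(out_bf, out_sf.transpose(0, 1)))  # True
```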
3. Bidirectional LSTM, batch_size in the first dimension

bidirectional = True
batch_first = True

# output shapes
output shape: torch.Size([3, 2, 12])
hn shape: torch.Size([4, 3, 6])
cn shape: torch.Size([4, 3, 6])
4. Bidirectional LSTM, batch_size in the second dimension

bidirectional = True
batch_first = False

# output shapes
output shape: torch.Size([2, 3, 12])
hn shape: torch.Size([4, 3, 6])
cn shape: torch.Size([4, 3, 6])
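In the bidirectional cases, hn stacks the directions into its first dimension. Per the PyTorch docs, hn of shape (num_layers * 2, batch, hidden) can be viewed as (num_layers, num_directions, batch, hidden) to separate the forward and backward states, for example:

```python
import torch
import torch.nn as nn

num_layers, batch, hidden = 2, 3, 6
lstm = nn.LSTM(input_size=5, hidden_size=hidden, num_layers=num_layers,
               batch_first=True, bidirectional=True)
x = torch.randn(batch, 2, 5)
output, (hn, cn) = lstm(x)  # hn: (num_layers * 2, batch, hidden)

# Separate layers and directions
hn_view = hn.view(num_layers, 2, batch, hidden)
forward_last = hn_view[-1, 0]   # last layer, forward direction: (batch, hidden)
backward_last = hn_view[-1, 1]  # last layer, backward direction: (batch, hidden)
print(forward_last.shape, backward_last.shape)
```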
Conclusions from the tests above:
1. h0 and hn (likewise c0 and cn) always have identical shapes, and the second dimension of h0/c0 must be batch_size. For a bidirectional LSTM the first dimension is twice the number of layers; for a unidirectional LSTM it equals the number of layers. In other words, hn and cn are unaffected by batch_first.
2. For the output tensor: with batch_first=True the first dimension is batch_size; with batch_first=False the second dimension is batch_size.
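One more relationship worth verifying: output holds the top layer's hidden state at every timestep, so for a unidirectional LSTM its last timestep equals hn's last layer. A quick check:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=5, hidden_size=12, num_layers=2, batch_first=True)
x = torch.randn(3, 2, 5)
output, (hn, cn) = lstm(x)

# Last timestep of the top layer's outputs == last layer of hn
print(torch.allclose(output[:, -1, :], hn[-1]))  # True
```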
Summary
With batch_first=False: input is (seq_len, batch_size, input_size), output is (seq_len, batch_size, num_directions * hidden_size), and hn/cn are (num_layers * num_directions, batch_size, hidden_size).
With batch_first=True: input is (batch_size, seq_len, input_size), output is (batch_size, seq_len, num_directions * hidden_size); the hn/cn shapes are unchanged.
Done, confetti! (a fan of Chong-ge ^~^)
Original post: https://blog.csdn.net/FortuneLegend/article/details/126306636