Pytorch nn.LSTM 使用注意事项

Pytorch nn.LSTM 使用注意事项

Pytorch框架已经实现了LSTM模型,直接使用nn.LSTM即可,具体接口参数可以自行查看官方文档,这里讨论一下nn.LSTM接口的输入、输出张量维度,LSTM网络如下。

在这里插入图片描述

先上代码如下:
import torch
import torch.nn as nn

batch_size = 3
sentence_length = 2
input_size = 5
hidden_size = 12
num_layers = 2
bidirectional = False
batch_first = True

if batch_first:
    input_tensor = torch.randn(batch_size, sentence_length, input_size)
else:
    input_tensor = torch.randn(sentence_length, batch_size, input_size)

if bidirectional:
    hidden_dim = hidden_size // 2
    h0 = torch.randn(num_layers * 2, batch_size, hidden_dim)
    c0 = torch.randn(num_layers * 2, batch_size, hidden_dim)
else:
    hidden_dim = hidden_size
    h0 = torch.randn(num_layers, batch_size, hidden_size)
    c0 = torch.randn(num_layers, batch_size, hidden_size)

# nn.LSTM api默认 batch_first=False
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, num_layers=num_layers, batch_first=batch_first, bidirectional=bidirectional)

output, (hn, cn) = lstm(input_tensor, (h0, c0))
print('output 维度:', output.shape)
print('hn 维度:', hn.shape)
print('cn 维度:', cn.shape)

1、单向lstm、batch_size在第一维

bidirectional = False
batch_first = True
#输出维度如下
output 维度: torch.Size([3, 2, 12])
hn 维度: torch.Size([2, 3, 12])
cn 维度: torch.Size([2, 3, 12])

2、单向lstm、batch_size在第二维

bidirectional = False
batch_first = False
#输出维度如下
output 维度: torch.Size([2, 3, 12])
hn 维度: torch.Size([2, 3, 12])
cn 维度: torch.Size([2, 3, 12])

3、双向lstm、batch_size在第一维

bidirectional = True
batch_first = True
output 维度: torch.Size([3, 2, 12])
hn 维度: torch.Size([4, 3, 6])
cn 维度: torch.Size([4, 3, 6])

4、双向lstm、batch_size在第二维

bidirectional = True
batch_first = False
output 维度: torch.Size([2, 3, 12])
hn 维度: torch.Size([4, 3, 6])
cn 维度: torch.Size([4, 3, 6])
通过上面的测试结论如下:
1、由于h0和hn,c0和cn的维度一定是相同的,同时h0和c0维度的第二维必须是batch_size;若是双向lstm,第一维是隐层层数的两倍;若是单向lstm,第一维是隐层层数。说明hn和cn是不受batch_first的影响。
2、输出张量维度,若指定batch_first为True,输出维度第一维则是batch_size;若batch_first为False,输出维度第二维则是batch_size。

总结

batch_first为False如下:

在这里插入图片描述

batch_first为True

在这里插入图片描述

完结,撒花。(虫哥的粉丝^~^)

本文地址:https://blog.csdn.net/FortuneLegend/article/details/126306636

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值