pytorch中LSTM的输入与输出理解

在阅读本篇博客之前希望你在LSTM方面有一定的知识储备,熟悉LSTM网络的内部结构,方便更好的理解pytorch中有关LSTM相关的api。

一、参数理解

这里我根据lstm的结构定义了一些参数,参数具体含义可以看注释

batch_size = 10 #每个batch的大小
seq_len = 2000 #模仿输入到LSTM的句子长度
input_size = 30 #lstm中输入的维度
hidden_size = 18 #lstm中隐藏层神经元的个数
num_layers = 2 # 有多少层lstm

二、数据准备

input = torch.randn(batch_size,seq_len,input_size)

三、LSTM

1、batch_first=True

pytorch中lstm输入和输出分为两种形式,一种是batch优先,另外一种则是batch第二,具体情况是指定lstm种参数batch_first=True,batch_first默认是False对应batch第二的情况,使用中我们一般将batch_first设置为True,采用batch优先的方式。如下

lstm = nn.LSTM(input_size=embedding_dim,hidden_size=hidden_size,num_layers=num_layers,batch_first=True)#,batch_first=True
out,(hn,cn) = lstm(input)
print(out.size())
print('*'*100)
print(hn.size())
print('*'*100)
print(cn.size())
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值