pytorch中的nn.LSTM模块参数详解

官网:https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM

Parameters(参数):

input_size :输入的维度

hidden_size:h的维度

num_layers:堆叠LSTM的层数,默认值为1

bias:偏置 ,默认值:True

batch_first: 如果是True,则input为(batch, seq, input_size)。默认值为:False(seq_len, batch, input_size

bidirectional :是否双向传播,默认值为False

输入

(input_size,hideen_size)

以训练句子为例子,假如每个词是100维的向量,每个句子含有24个单词,一次训练10个句子。那么batch_size=10,seq=24,input_size=100。(seq指的是句子的长度,input_size作为一个x_{t}的输入) ,所以在设置LSTM网络的过程中input_size=100。由于seq的长度是24,那么这个LSTM结构会循环24次最后输出预设的结果。如下图所示。

h的输出主要是看预设的hidden_size,这个hideen_size主要是下面LSTM公式中的各个W和b的维度设置,以g_{t}为例子,假设hideen_size为16,则W_{ig}为16*100,x_{t}为100*1,W_{hg}为16*16,h_{t-1}为16*1。

                                      

(num_layers,bidirectional)

两个的结构的区别如下图所示:

输出 

output :(seq_len, batch, num_directions * hidden_size)

h_n:(num_layers * num_directions, batch, hidden_size)

c_n :(num_layers * num_directions, batch, hidden_size)

如何输出,可以看上面那张图,图中有具体的。

此外,在输入的过程中,也可以给定h_{0},c_{0},如果没有给定那么默认为0。

(1)例子num_layers

import torch.nn as nn
import torch
x = torch.rand(10,24,100)
lstm = nn.LSTM(100,16,num_layers=2)
output,(h,c) = lstm(x)
print(output.size())
print(h.size())
print(c.size())

output:
torch.Size([24, 10, 16])
torch.Size([2, 10, 16])
torch.Size([2, 10, 16])

(2) 例子 bidirectional

import torch.nn as nn
import torch
x = torch.rand(10,24,100)
lstm = nn.LSTM(100,16,bidirectional=True)
output,(h,c) = lstm(x)
print(output.size())
print(h.size())
print(c.size())

output:
torch.Size([24, 10, 32])
torch.Size([2, 10, 16])
torch.Size([2, 10, 16])

(3) 例子 h0 c0

import torch.nn as nn
import torch
x = torch.rand(24,10,100) #seq,batch,input_size
h0 = torch.rand(1,10,16)# num_layers*num_directions, batch, hidden_size
c0 = torch.rand(1,10,16)
lstm = nn.LSTM(100,16)
output,(h,c) = lstm(x,(h0,c0))

 

评论 22
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Foneone

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值