lstm理解与使用(pytorch为例)

一.lstm原理

可以详读:Understanding LSTM Networks
http://colah.github.io/posts/2015-08-Understanding-LSTMs/

1.结构图

在这里插入图片描述

2.公式
  • 遗忘门,决定哪些东西被遗忘:
    在这里插入图片描述

  • 输入门,决定状态Cell里面的更新:
    在这里插入图片描述

  • C t C_t Ct状态更新,与遗忘门和输入们有关
    在这里插入图片描述

  • 输出门,决定t时刻的输出:
    在这里插入图片描述
    其中, h t − 1 h_{t-1} ht1为上个时间节点t-1时刻的输出,输出的维度可以设定,例如pytorch中:

rnn = nn.LSTM(10, 20, 2)  #(input_size,hidden_size,num_layers)

hidden_size的维度,其实就是设定的 h h h的维度,并且每一个时间节点t,都会有一个输出,一般情况采用最后时刻的输出,当然也可以利用各个时间点的hidden层的特征,来综合做判断,attention就是这样子的。

二、lstm的pytorch使用与理解

2.1 单向lstm的使用
rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)#(input_size,hidden_size,num_layers)
input = torch.randn(5, 3, 10)#(seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20) #(num_layers,batch,output_size)
c0 = torch.randn(2, 3, 20) #(num_layers,batch,output_size)
output, (hn, cn) = rnn(input, (h0, c0))
output.shape #(seq_len, batch, output_size)
torch.Size([5, 3, 20])
hn.shape #(num_layers, batch, output_size)
torch.Size([2, 3, 20])
cn.shape #(num_layers, batch, output_size)
torch.Size([2, 3, 20])

如何理解呢?可以看下图以及解释:
在这里插入图片描述

  • output:对于每一个step,相当于也就是seq_len中的每一步,都有一个output_size维度的特征输出,所以output的维度是5,3,20(seq_len, batch, output_size).
  • hc其实是最后一个时间节点t的隐藏层的特征,因为lstm中设置了num_layers=2,所以每一层lstm最后一个时间节点都会有一个output_size维度特征的输出,所以他的输出维度为2, 3, 20(num_layers, batch, output_size)
  • cn与hc相同
    上图来源于:LSTM神经网络输入输出究竟是怎样的?https://www.zhihu.com/question/41949741/answer/318771336
2.2 双向lstm的使用
rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2,bidirectional=True)#(input_size,hidden_size,num_layers)
input = torch.randn(5, 3, 10)#(seq_len, batch, input_size)
h0 = torch.randn(4, 3, 20) #(num_layers,batch,output_size)
c0 = torch.randn(4, 3, 20) #(num_layers,batch,output_size)
output, (hn, cn) = rnn(input, (h0, c0))
output.shape #(seq_len, batch, output_size*2)
torch.Size([5, 3, 40])
hn.shape #(num_layers*2, batch, output_size)
torch.Size([4, 3, 20])
cn.shape #(num_layers*2, batch, output_size)
torch.Size([4, 3, 20])
  • 其实就是反向加了一层lstm,然后输出的时候再concat起来,所以维度上会*2,bilstm如下图所示:

在这里插入图片描述

三、pytorch变长lstm的使用

1.pytorch中如何处理RNN输入变长序列padding

https://zhuanlan.zhihu.com/p/34418001

2.教你几招搞定 LSTMs 的独门绝技(附代码)

https://zhuanlan.zhihu.com/p/40391002

  • 43
    点赞
  • 206
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值