tensorflow lstm从隐状态到预测值_LSTM原理及实战

一、基本原理

要理解LSTM可以先来看一下什么是RNN?

1.1 RNN原理

一般来说,RNN的输入和输出都是一个序列,分别记为

,同时
的取值不仅与
有关还与序列中更早的输入有关(序列中的第t个元素我们叫做序列在time_step=t时的取值)。更直观的理解可看下图:

6a4390c344c1660c6cff5970223935c0.png

把上图用公式表达就是:

训练RNN需要用BPTT去优化,但是当序列过长时很容易引起梯度爆炸或梯度消失现象。

1.2 LSTM原理

LSTM是一种特殊的RNN,主要通过三个门控逻辑实现(遗忘、输入、输出)。它的提出就是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。它的原理可看如下4图:

9696fac83d90bc74760393bcec94f99f.png
图1 遗忘门计算逻辑

54c10e8dffe326532ebac8551f6fff22.png
图2 输入门及当前候选细抱状态

f64c478a739ed9cb5dfc81ca06ecafcd.png
图3 当前细抱状态

cded159c6caaf6490fd92d0750fe295b.png
图4 输出门及最终输出

LSTM的典型应用(基于hidden state去预测):基于语言模型预测下一个单词、词性标注等。实际应用中还有很多LSTM的变种,典型的是GRU(参考文献2)。

二、基于Pytorch的实战

2.1 核心API

torch.nn.LSTM(*args,**kwargs)

其构造器的参数列表如下:

  • input_size – 每个time step中其输入向量
    的维度。
  • hidden_size – 每个time step中其隐藏状态向量
    的维度。
  • num_layers – 每个time step中其纵向有几个LSTM单元,默认为1。如果取2,第二层的
    是第一层的
    ,有时也会加一个dropout因子。
  • bias – 如果为False,则计算中不用偏置,默认为True。
  • batch_first –若为True,则实际调用时input和output张量格式为(batch, seq, feature),默认为False。
  • dropout – 是否加dropout,Default: 0。
  • bidirectional – 是否为双向LSTM,Default: False。

定义了模型,实际调用按如下方式:

lstm = nn.LSTM(3, 3)

# Inputs: input, (h_0, c_0) 
# Outputs: output, (h_n, c_n)
Outputs=lstm(Inputs)

注意上述代码中:
1)h_0, c_0分别代表batch中每个元素的hidden state和cell state的初始化值。

2)h_n, c_n分别代表当t = seq_len时,hidden state和cell state的值。

3)如果batch_first=False时,input格式为:(seq_len, batch=1, input_size),output格式为:(seq_len, batch=1, num_directions * hidden_size)。但是当batch_first=True时,input的格式变为:(batch_size, seq_len, input_size),而output的格式变为:(batch_size, seq_len, num_directions * hidden_size)。

2.2 LSTM实战

1)简单demo

# simple demo

2)完整demo

###数据准备

【参考资料】

1、pytorch LSTM API:https://pytorch.org/docs/stable/nn.html?highlight=lstm#torch.nn.LSTM。

2、Understanding LSTM Networks:http://colah.github.io/posts/2015-08-Understanding-LSTMs/。

3、一文搞懂RNN:https://zhuanlan.zhihu.com/p/30844905。

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值