002_wz_NLP_LSTM和LSTMCell

LSTM介绍

先去看RNN
在这里插入图片描述
LSTM使用“门”(sigmoid函数)解决了RNN不能记忆长距离信息的问题,使用累加方式的损失解决了RNN梯度爆炸/弥散的问题

LSTM中有三个门,分别是:
1.forget gate(遗忘门)
遗忘门用来控制记住以前记忆( C t − 1 C_{t-1} Ct1)的程度
在这里插入图片描述
2.input gate(输入门)
输入门用来控制记忆此时刻输入数据( X t X_t Xt)的程度
在这里插入图片描述
此时在经过了遗忘门和输入门,我们可以拿到本次的记忆数据 C t C_t Ct
在这里插入图片描述
这里介绍一下, f t f_t ft是遗忘门, i t i_t it是输入门, C t − 1 C_{t-1} Ct1相当于RNN中的 h t W h h h_tW_{hh} htWhh,~ C t C_t Ct相当于 x t W i h x_tW_{ih} xtWih,门在这里是起一个选择的作用,LSTM后面还需要再经过输出门,得到 h t h_t ht
在这里插入图片描述
3.output gate(输出门)
输出门用来控制输出 h t h_t ht的程度
在这里插入图片描述
下面是LSTM的简化整体流程
在这里插入图片描述
迭代公式
在这里插入图片描述
不同门关闭与开启,造成的效果
在这里插入图片描述

torch.nn.LSTM

输入参数:
input_size:即对数据做embedding的数据维度feature_len
hidden_size:LSTM的隐层维度
num_layers:LSTM网络的层数,默认为1层

LSTM的前向传播

out, (h_t, c_t) = lstm(x, (h_t0, c_t0)

x:输入数据,维度为(seq_len, batch_size, feature_len)
h/c:上一隐层的输出,维度为(num_layers, batch_size, hidden_size)
out:为每一时刻隐层输出的列表集合,形如[h_1, h_2, …, h_t],维度为(seq_len, batch_size, hidden_size)

LSTM代码验证

import torch

lstm = torch.nn.LSTM(input_size=100, hidden_size=20, num_layers=4)

x = torch.randn(10, 3, 100)
h_0 = torch.zeros(4, 3, 20)
c_0 = torch.zeros(4, 3, 20)

out, (h_t, c_t) = lstm(x, (h_0, c_0))

print(out.shape, h_t.shape, c_t.shape)

torch.Size([10, 3, 20]) torch.Size([4, 3, 20]) torch.Size([4, 3, 20])

torch.nn.LSTMCell

输入与LSTM相同,但是没有num_layers

LSTMCell的前向传播

h_t, c_t = lstmcell(x_t, (h_t0, c_t0))

x_t:单个输入数据,维度为(batch_size, feature_len)
h_t/c_t:隐层输出,维度为(batch_size, hidden_size)

LSTM代码验证

import torch

lstmcell = torch.nn.LSTMCell(input_size=100, hidden_size=20)

x = torch.randn(10, 3, 100)
h_0 = torch.zeros(3, 20)
c_0 = torch.zeros(3, 20)

for x_t in x:
    h_t, c_t = lstmcell(x_t, (h_0, c_0))

print(h_t.shape, c_t.shape)

torch.Size([3, 20]) torch.Size([3, 20])

import torch

lstmcell1 = torch.nn.LSTMCell(input_size=100, hidden_size=20)
lstmcell2 = torch.nn.LSTMCell(input_size=20, hidden_size=10)

x = torch.randn(10, 3, 100)
h_0 = torch.zeros(3, 20)
c_0 = torch.zeros(3, 20)

h_1 = torch.zeros(3, 10)
c_1 = torch.zeros(3, 10)

for x_t in x:
    h_0, c_0 = lstmcell1(x_t, (h_0, c_0))
    h_1, c_1 = lstmcell2(h_0, (h_1, c_1))

print(h_1.shape, c_1.shape)

torch.Size([3, 10]) torch.Size([3, 10])

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值