[NLP] Description and implementation of the LSTM neural network.


0. Statement 🏫

Today I intend to move from an intuitive understanding of LSTM to its implementation in PyTorch, and I believe readers can get substantial help from this post.

1. What is the advantage of LSTM over RNN? 🤔

An RNN can remember far less context than an LSTM, so the plain RNN is effectively a short-term memory network, while the LSTM is a Long Short-Term Memory network, where "long" means that it can remember more context than the RNN, which represents the short-term network. So you don't have to be confused by the name (Long Short-Term). 🤗

2. Differences in notation between the RNN and LSTM diagrams. 🧐

The RNN and LSTM pictures are from here.
[Figure: RNN architecture diagram]
[Figure: LSTM architecture diagram]

  1. The output of the RNN is "ot", while the output of the LSTM is "ht".
  2. The contextual information (memory) of the RNN is stored in "ht" (above), while the contextual information (memory) of the LSTM is stored in "ct".

3. The composition and intuitive understanding of LSTM. 🤠

When I first saw the architecture diagram of the LSTM, I noticed a recurring schematic of a sigmoid followed by an element-wise multiplication.
The sigmoid takes values between 0 and 1, so the numbers it multiplies are scaled by something close to 0 or close to 1; the significance of each of these structures is to decide whether to use the data from the previous time step. This structure is named a "gate".
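As a tiny illustration (my own toy tensors, not part of the original post), here is how a sigmoid output acts as a soft keep/forget switch in PyTorch:

import torch

data = torch.tensor([2.0, -1.0, 3.0])           # information to be gated
gate_logits = torch.tensor([10.0, -10.0, 0.0])  # large -> keep, very negative -> forget
gate = torch.sigmoid(gate_logits)                # approx. [1.0, 0.0, 0.5]
print(gate)
print(gate * data)                               # approx. [2.0, 0.0, 1.5]: kept, dropped, halved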

3.1. LSTM: Forget gate.

When the value of "ft" is 1, it means we keep the data remembered at the previous time step, and when the value is 0, it means we forget it.
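Written out (the standard LSTM formulation, added here for reference), the forget gate is a sigmoid of the previous hidden state and the current input, and it scales the previous cell state element-wise:

$$f_t = \sigma\left(W_f \cdot [h_{t-1}, x_t] + b_f\right), \qquad f_t \odot c_{t-1}$$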

3.2. LSTM: Input gate and Cell State.

When the value of "it" is 1, it means we use the data computed at the current time step ("C wave t"), which is calculated by "tanh", "Wc" and "bc" from the current input "xt" (together with h(t-1)).
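Written out (standard formulation), the input gate and the candidate cell state ("C wave t", i.e. $\tilde{c}_t$) are:

$$i_t = \sigma\left(W_i \cdot [h_{t-1}, x_t] + b_i\right), \qquad \tilde{c}_t = \tanh\left(W_c \cdot [h_{t-1}, x_t] + b_c\right)$$

and the new cell state combines the forget gate and the input gate:

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$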
⚠️: In summary, the forget gate decides whether the information remembered at the previous time step is still useful, and the input gate decides whether the information computed at the current time step is important enough to remember.

3.3. LSTM: Output.

In summary, the output "ht" of the LSTM is the element-wise product of the output gate and the "tanh" of "ct".
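In equations (standard formulation), the output gate and the hidden state are:

$$o_t = \sigma\left(W_o \cdot [h_{t-1}, x_t] + b_o\right), \qquad h_t = o_t \odot \tanh(c_t)$$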

3.4. Summary

The forget gate, input gate, and output gate each apply a sigmoid to the current input x(t) and the context information h(t-1) from the previous time step. The candidate "c wave" written into the cell requires a tanh over the same x(t) and h(t-1). The c(t) and h(t) passed on to the next time step then follow directly from the diagram above. 😇
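To make the summary concrete, here is a minimal hand-written sketch of one LSTM step in PyTorch (the weight names W_f, W_i, W_c, W_o and the helper lstm_step are my own placeholders, not code from the original post; shapes assume input_size=100 and hidden_size=20):

import torch

torch.manual_seed(0)
input_size, hidden_size = 100, 20

# one (W, b) pair per gate, each applied to the concatenation [h_{t-1}, x_t]
W_f, b_f = torch.randn(hidden_size, hidden_size + input_size), torch.zeros(hidden_size)
W_i, b_i = torch.randn(hidden_size, hidden_size + input_size), torch.zeros(hidden_size)
W_c, b_c = torch.randn(hidden_size, hidden_size + input_size), torch.zeros(hidden_size)
W_o, b_o = torch.randn(hidden_size, hidden_size + input_size), torch.zeros(hidden_size)

def lstm_step(x_t, h_prev, c_prev):
    z = torch.cat([h_prev, x_t])            # [h_{t-1}, x_t]
    f_t = torch.sigmoid(W_f @ z + b_f)      # forget gate
    i_t = torch.sigmoid(W_i @ z + b_i)      # input gate
    c_tilde = torch.tanh(W_c @ z + b_c)     # candidate cell state ("C wave t")
    c_t = f_t * c_prev + i_t * c_tilde      # new memory
    o_t = torch.sigmoid(W_o @ z + b_o)      # output gate
    h_t = o_t * torch.tanh(c_t)             # new hidden state / output
    return h_t, c_t

h, c = torch.zeros(hidden_size), torch.zeros(hidden_size)
h, c = lstm_step(torch.randn(input_size), h, c)
print(h.shape, c.shape)  # torch.Size([20]) torch.Size([20])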


3.5. Understanding the role of "gates" intuitively.

3.6. Why can LSTM mitigate gradient vanishing?

When we compute the gradient, the cell-state path avoids the k-th power of "Whh" that appears in a plain RNN. Moreover, because there are three "gates", the gradient expands into three terms, and the gates constrain each other, so the probability of the result becoming very large or very small is much lower.
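To make this one step more concrete (a rough sketch, ignoring the indirect dependence of the gates on earlier cell states): from $c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$ we get approximately

$$\frac{\partial c_t}{\partial c_{t-1}} \approx f_t,$$

so the gradient flowing along the cell-state path is a product of forget-gate values, each between 0 and 1 and learned per time step, rather than the k-th power of a fixed matrix "Whh".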

4. How to implement LSTM with PyTorch? 😎

The PyTorch documentation for nn.LSTM is here.

import torch
from torch import nn

# 4 stacked LSTM layers, each with hidden_size=20, expecting 100-dim inputs
lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)
print(lstm)

# default layout is (seq_len, batch, input_size) because batch_first=False
x = torch.randn(10, 3, 100)
out, (h, c) = lstm(x)

print(out[-1])            # top layer's hidden state at the last time step
print(h[-1])              # last time step's hidden state of the top (4th) layer
print(out[-1].shape)      # torch.Size([3, 20])
print(h[-1].shape)        # torch.Size([3, 20])
print(out[-1] == h[-1])   # all True: these are the same values
print(f"out.shape:{out.shape}\nh.shape:{h.shape}\nc.shape:{c.shape}")
# out: (10, 3, 20), h: (4, 3, 20), c: (4, 3, 20)

⚠️: "out" collects the top layer's h for every time step, while "h" collects the final time step's h for every layer, so "out[-1]" (last time step, top layer) equals "h[-1]" (top layer, last time step). 🤠


print('one layer lstm')
# nn.LSTMCell handles a single time step, so we loop over the sequence ourselves
cell = nn.LSTMCell(input_size=100, hidden_size=20)
h = torch.zeros(3, 20)   # (batch, hidden_size)
c = torch.zeros(3, 20)
for xt in x:             # xt: (batch, input_size) = (3, 100)
    h, c = cell(xt, (h, c))
print(f"h.shape:{h.shape}")   # torch.Size([3, 20])
print(f"c.shape:{c.shape}")   # torch.Size([3, 20])

print("two layer lstm")
# stack two cells by hand: the first cell's h feeds the second cell
cell1 = nn.LSTMCell(input_size=100, hidden_size=30)
cell2 = nn.LSTMCell(input_size=30, hidden_size=20)
h1 = torch.zeros(3, 30)
c1 = torch.zeros(3, 30)
h2 = torch.zeros(3, 20)
c2 = torch.zeros(3, 20)
for xt in x:
    h1, c1 = cell1(xt, (h1, c1))
    h2, c2 = cell2(h1, (h2, c2))
print(f"h2.shape:{h2.shape}")  # torch.Size([3, 20])
print(f"c2.shape:{c2.shape}")  # torch.Size([3, 20])


Finally 🤩

Thank you to this age of knowledge sharing and to the people willing to share; the knowledge in this blog is what I have learned from others on this site. Thanks for the support! 😇
