0. Statement 🏫
In this post I move from an intuitive understanding of LSTM to its implementation in PyTorch, and I hope readers will find it genuinely helpful.
1. What is the advantage of LSTM over RNN? 🤔
An RNN can hold on to far less contextual information than an LSTM, which is why the RNN is often described as having only short-term memory. The LSTM is a Long Short-Term Memory network: the “long” means it can remember context over many more time steps than a plain RNN, which plays the role of the “short-term” network. So don’t be confused by the name (Long Short-Term). 🤗
2. Differences in notation between the RNN and LSTM diagrams. 🧐
The RNN and LSTM pictures are from here.
- The output of the RNN is ot, while the output of the LSTM is ht.
- The contextual information (memory) of the RNN is stored in ht (shown above), while the contextual information (memory) of the LSTM is stored in ct.
3. The composition and intuitive understanding of LSTM. 🤠
When I first saw the architecture diagram of the LSTM, I noticed a recurring pattern: a sigmoid feeding into an element-wise multiplication.
The sigmoid outputs values between 0 and 1, so whatever it multiplies is scaled by a factor between 0 and 1. In other words, each of these structures decides how much of the data from the previous time step to let through. This structure is named “gate”.
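A tiny numerical sketch of this gating idea (the values are made up for illustration): a sigmoid output near 0 blocks a number, and one near 1 lets it pass almost unchanged.

import torch

memory = torch.tensor([2.0, -3.0, 5.0])                  # values we may want to keep
gate = torch.sigmoid(torch.tensor([-10.0, 0.0, 10.0]))   # roughly [0.0, 0.5, 1.0]
print(gate * memory)                                     # roughly [0.0, -1.5, 5.0]: blocked, halved, passed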
3.1. LSTM: Forget gate.
When the value of “ft” is 1, it means “keep the data remembered at the previous time step”; when the value is 0, it means “forget it”.
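Written out in the standard notation (the forget gate has its own weights and bias, applied to the previous hidden state and the current input):

$$f_t = \sigma\left(W_f \cdot [h_{t-1}, x_t] + b_f\right)$$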
3.2. LSTM: Input gate and Cell State.
When the value of “it” is 1, it means “use the information computed from the current input”. That candidate, “C̃t” (the “C wave t” in the diagram), is computed with “tanh”, “Wc” and “bc” from the current input “xt” together with the previous hidden state “ht-1”.
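In the standard notation, the input gate and the candidate cell state are computed in the same way as the forget gate, and together with the forget gate they update the cell state:

$$i_t = \sigma\left(W_i \cdot [h_{t-1}, x_t] + b_i\right)$$
$$\tilde{C}_t = \tanh\left(W_C \cdot [h_{t-1}, x_t] + b_C\right)$$
$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$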
⚠️: In summary, the forget gate decides whether the information remembered at the previous time step is still useful, and the input gate decides whether the information from the current time step is important enough to remember.
3.3. LSTM: Output.
In summary, the output “ht” of the LSTM is the element-wise product of the output gate “ot” and “tanh” applied to the cell state “ct”.
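In the standard notation:

$$o_t = \sigma\left(W_o \cdot [h_{t-1}, x_t] + b_o\right)$$
$$h_t = o_t \odot \tanh(C_t)$$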
3.4. Summary
The forget gate, input gate, and output gate each apply a sigmoid to the current input x(t) and the previous hidden state h(t-1). The candidate cell state C̃(t) fed into the cell is computed with tanh from the same x(t) and h(t-1). With the diagram above, the c(t) and h(t) passed on to the next time step are then easy to follow. 😇
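To make the data flow concrete, here is a minimal sketch of a single LSTM time step written directly from the equations above. The parameter names and the concatenated [h, x] layout are my own choices for illustration; PyTorch's built-in nn.LSTM stores its weights differently.

import torch

def lstm_step(x_t, h_prev, c_prev, W_f, b_f, W_i, b_i, W_c, b_c, W_o, b_o):
    # One LSTM time step following the gate equations above.
    hx = torch.cat([h_prev, x_t], dim=1)   # concatenate h(t-1) and x(t)
    f = torch.sigmoid(hx @ W_f + b_f)      # forget gate
    i = torch.sigmoid(hx @ W_i + b_i)      # input gate
    c_tilde = torch.tanh(hx @ W_c + b_c)   # candidate cell state
    c = f * c_prev + i * c_tilde           # new cell state
    o = torch.sigmoid(hx @ W_o + b_o)      # output gate
    h = o * torch.tanh(c)                  # new hidden state
    return h, c

# Toy shapes: batch=3, input_size=100, hidden_size=20
B, I, H = 3, 100, 20
Ws = [torch.randn(I + H, H) * 0.1 for _ in range(4)]
bs = [torch.zeros(H) for _ in range(4)]
x_t = torch.randn(B, I)
h, c = torch.zeros(B, H), torch.zeros(B, H)
h, c = lstm_step(x_t, h, c, Ws[0], bs[0], Ws[1], bs[1], Ws[2], bs[2], Ws[3], bs[3])
print(h.shape, c.shape)  # torch.Size([3, 20]) torch.Size([3, 20])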
3.5. Understanding the role of “gates” intuitively.
3.6. Why can LSTM mitigate gradient vanishing?
Because when we back-propagate through the cell state, the gradient no longer reduces to the k-th power of “Whh” as it does in a plain RNN. Differentiating the cell-state update instead expands into terms for the three gates, and these gate terms constrain one another, so the product of factors is far less likely to become extremely large or extremely small.
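A rough sketch of the argument (not a full derivation): differentiating the cell-state update $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$ with respect to $C_{t-1}$ gives the forget gate as the direct term, plus further terms that flow through the gates. Keeping only the direct term,

$$\frac{\partial C_t}{\partial C_{t-1}} \approx f_t \quad\Longrightarrow\quad \frac{\partial C_t}{\partial C_{t-k}} \approx \prod_{j=t-k+1}^{t} f_j,$$

so as long as the forget gate stays close to 1, the gradient along the cell-state path need not shrink, whereas the plain RNN's gradient contains the repeated factor Whh raised to the k-th power.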
4. How to implement LSTM with PyTorch? 😎
The PyTorch documentation on LSTM is from here.
import torch
from torch import nn

# A 4-layer LSTM: 100-dimensional inputs, a 20-dimensional hidden state per layer
lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)
print(lstm)

# Default layout: (seq_len=10, batch=3, input_size=100)
x = torch.randn(10, 3, 100)
out, (h, c) = lstm(x)

print(out[-1])           # hidden state of the last layer at the final time step
print(h[-1])             # also the hidden state of the last layer at the final time step
print(out[-1].shape)     # torch.Size([3, 20])
print(h[-1].shape)       # torch.Size([3, 20])
print(out[-1] == h[-1])  # all True: the two tensors hold the same values
print(f"out.shape:{out.shape}\nh.shape:{h.shape}\nc.shape:{c.shape}")
⚠️: “out” holds the hidden state h of the last layer at every time step, while “h” holds the hidden state of every layer at the last time step, which is why out[-1] equals h[-1]. 🤠
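If your data comes as (batch, seq_len, features) instead of (seq_len, batch, features), nn.LSTM accepts batch_first=True. A small sketch, reusing the imports from the listing above (the variable names here are mine); note that h and c keep the (num_layers, batch, hidden_size) layout either way:

lstm_bf = nn.LSTM(input_size=100, hidden_size=20, num_layers=4, batch_first=True)
x_bf = torch.randn(3, 10, 100)   # (batch=3, seq_len=10, input_size=100)
out_bf, (h_bf, c_bf) = lstm_bf(x_bf)
print(out_bf.shape)              # torch.Size([3, 10, 20]): batch-first output
print(h_bf.shape)                # torch.Size([4, 3, 20]): h and c are NOT batch-first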
# The imports and the input tensor are the same as in the previous listing.
import torch
from torch import nn
x = torch.randn(10, 3, 100)  # (seq_len=10, batch=3, input_size=100)
print('one layer lstm')
# A single LSTMCell handles one time step at a time, so we drive the time loop ourselves.
cell = nn.LSTMCell(input_size=100, hidden_size=20)
h = torch.zeros(3, 20)  # initial hidden state (batch=3, hidden_size=20)
c = torch.zeros(3, 20)  # initial cell state (batch=3, hidden_size=20)
for xt in x:
    h, c = cell(xt, (h, c))
print(f"h.shape:{h.shape}")
print(f"c.shape:{c.shape}")
print("two layer lstm")
cell1 = nn.LSTMCell(input_size=100, hidden_size=30)
cell2 = nn.LSTMCell(input_size=30, hidden_size=20)
h1 = torch.zeros(3, 30)
c1 = torch.zeros(3, 30)
h2 = torch.zeros(3, 20)
c2 = torch.zeros(3, 20)
for xt in x:
h1, c1 = cell1(xt, [h1, c1])
h2, c2 = cell2(h1, [h2, c2])
print(f"h2.shape:{h2.shape}")
print(f"c2.shape:{c2.shape}")
Finally 🤩
Thank you to this era of knowledge sharing and to everyone willing to share. The knowledge in this post is what I have learned on this site, so thanks for the support! 😇