1、LSTM结构图(多层)
2、LSTM 单个cell结构
3、底层计算验证
$$
\begin{array}{ll}
f_{(forget)} = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) & \text{遗忘门} \\
i_{(input)} = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) & \text{输入门} \\
g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) & \text{候选单元状态} \\
o_{(output)} = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) & \text{输出门} \\
c' = f * c + i * g & \\
h' = o * \tanh(c') &
\end{array}
$$
其中 $x$ 为本次输入,$h$、$c$ 为上一步的隐藏状态和单元状态,$h'$、$c'$ 为本步输出,$*$ 表示逐元素相乘。注:下面的验证代码使用 bias=False,因此各偏置项 $b$ 均为 0。
# 验证经过一个cell的计算
import torch
import torch.nn as nn
# 1. Problem dimensions and a random input batch.
feature_size = 4
batch_size = 1
hidden_size = 10
x = torch.randn(batch_size, feature_size)

# 2. Reference result: run one step through torch's built-in LSTMCell to get ht, ct.
# bias=False removes every b_* term, so only the weight matrices have to be
# reproduced by hand below.
lstm = nn.LSTMCell(input_size=feature_size, hidden_size=hidden_size, bias=False)
h0 = torch.zeros(size=(batch_size, hidden_size))
c0 = torch.zeros(size=(batch_size, hidden_size))
ht, ct = lstm(x, (h0, c0))
print(f'调用LSTMCell模块计算{ht}')
print(f'调用LSTMCell模块计算{ct}')

# 3. Manual computation of the same cell step, producing ho, co.
# In theory an LSTM has 4 input-weight matrices (wii, wif, wig, wio) and
# 4 hidden-weight matrices (whi, whf, whg, who), 8 in total. torch stacks each
# group of 4 along dim 0 into a single matrix so one matmul serves all gates.
wih = lstm.weight_ih  # shape=(4*hidden_size, feature_size)
whh = lstm.weight_hh  # shape=(4*hidden_size, hidden_size)

# 3.1 Project the previous hidden state and the current input, sum them, then
# split the result into the four gate pre-activations.
ht_1 = torch.mm(input=h0, mat2=torch.t(whh))
xt = torch.mm(input=x, mat2=torch.t(wih))
hx = torch.add(ht_1, xt)  # shape=(batch_size, 4*hidden_size)
# Split along the feature axis so each gate keeps shape (batch_size, hidden_size).
# Reshaping to (-1, hidden_size) and indexing rows 0..3 only works for
# batch_size == 1; with larger batches it mixes samples and gates.
i, f, g, o = torch.chunk(hx, chunks=4, dim=1)

# 3.2 Forget gate: scale the previous cell state.
c1 = torch.multiply(input=c0, other=torch.sigmoid(f))

# 3.3 Input gate: add the gated candidate values to the cell state.
c2 = torch.add(input=c1, other=torch.multiply(input=torch.sigmoid(i), other=torch.tanh(g)))

# 3.4 Output gate: the new cell state is c2; the new hidden state is its
# tanh scaled by the output gate.
co = c2
ho = torch.multiply(input=torch.tanh(c2), other=torch.sigmoid(o))
print(f'手动根据结构图计算{ho}')
print(f'手动根据结构图计算{co}')
# Sample output from one run, kept for reference. The exact numbers differ on
# every run because x and the LSTMCell weights are randomly initialized; what
# matters is that the LSTMCell results and the manual results are identical.
'''
调用LSTMCell模块计算tensor([[-0.0115, -0.0040, 0.0376, -0.0131, 0.0128, -0.0104, 0.0382, -0.0359,
-0.0498, 0.0463]], grad_fn=<MulBackward0>)
调用LSTMCell模块计算tensor([[-0.0244, -0.0077, 0.0661, -0.0289, 0.0255, -0.0201, 0.0829, -0.0789,
-0.1051, 0.0888]], grad_fn=<AddBackward0>)
手动根据结构图计算tensor([[-0.0115, -0.0040, 0.0376, -0.0131, 0.0128, -0.0104, 0.0382, -0.0359,
-0.0498, 0.0463]], grad_fn=<MulBackward0>)
手动根据结构图计算tensor([[-0.0244, -0.0077, 0.0661, -0.0289, 0.0255, -0.0201, 0.0829, -0.0789,
-0.1051, 0.0888]], grad_fn=<AddBackward0>)
'''