Theory
The core of an LSTM is the cell state, represented by the horizontal line running through the cell.
1. Compute the forget gate
Decides which useless information the cell state should discard:
$f_t = \sigma_g(W_f x_t + U_f h_{t-1} + b_f)$
2. Compute the input gate
Decides which useful information the cell state should take in:
$i_t = \sigma_g(W_i x_t + U_i h_{t-1} + b_i)$
3. Compute the candidate cell state
$\tilde{c}_t = \sigma_c(W_c x_t + U_c h_{t-1} + b_c)$
4. Update the cell state
$c_t = f_t \circ c_{t-1} + i_t \circ \tilde{c}_t$
5. Compute the output gate
Controls which information in the cell state gets output:
$o_t = \sigma_g(W_o x_t + U_o h_{t-1} + b_o)$
6. Compute the output hidden state
$h_t = o_t \circ \sigma_h(c_t)$
Here $\sigma_g$ is the sigmoid function, while $\sigma_c$ and $\sigma_h$ are typically tanh. The six steps are pulled together in the single-timestep sketch below.
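To make the equations concrete, here is a minimal single-timestep sketch in PyTorch. The shapes, the random parameters, and the names x_t, h_prev, c_prev are assumptions for illustration only.

import torch

# Shapes below are illustrative assumptions, not from the original text
batch, input_size, hidden_size = 4, 8, 16

x_t    = torch.randn(batch, input_size)    # current input x_t
h_prev = torch.randn(batch, hidden_size)   # previous hidden state h_{t-1}
c_prev = torch.randn(batch, hidden_size)   # previous cell state c_{t-1}

# One randomly initialized (W, U, b) triple per gate, just for illustration
params = {g: (torch.randn(hidden_size, input_size),
              torch.randn(hidden_size, hidden_size),
              torch.zeros(hidden_size)) for g in "fico"}

def affine(W, U, b):
    # The shared pattern W x_t + U h_{t-1} + b of steps 1, 2, 3, and 5
    return x_t @ W.T + h_prev @ U.T + b

f_t = torch.sigmoid(affine(*params["f"]))   # 1. forget gate
i_t = torch.sigmoid(affine(*params["i"]))   # 2. input gate
c_tilde = torch.tanh(affine(*params["c"]))  # 3. candidate cell state
c_t = f_t * c_prev + i_t * c_tilde          # 4. cell state update (∘ is elementwise)
o_t = torch.sigmoid(affine(*params["o"]))   # 5. output gate
h_t = o_t * torch.tanh(c_t)                 # 6. output hidden state
# Note: with f_t ≈ 1 and i_t ≈ 0, c_t ≈ c_{t-1}, so the cell state passes through unchanged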
Practice
Implementing an LSTM from scratch
import torch
import torch.nn as nn

class My_LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # One fused linear layer produces all four gate pre-activations at once
        self.gates = nn.Linear(input_size + hidden_size, hidden_size * 4)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        # Small head mapping the hidden state to the per-step output
        self.output = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, output_size)
        )
        # Xavier-initialize every weight matrix (biases keep their defaults)
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, x):
        # x: (batch_size, seq_len, input_size)
        batch_size = x.size(0)
        seq_len = x.size(1)
        # Zero-initialize the hidden state h_0 and cell state c_0
        h, c = (torch.zeros(batch_size, self.hidden_size).to(x.device) for _ in range(2))
        y_list = []
        for i in range(seq_len):
            # One matmul computes all four gates, then split into quarters
            forget_gate, input_gate, output_gate, candidate_cell = \
                self.gates(torch.cat([x[:, i, :], h], dim=-1)).chunk(4, -1)
            forget_gate, input_gate, output_gate = (self.sigmoid(g)
                for g in (forget_gate, input_gate, output_gate))
            # c_t = f_t ∘ c_{t-1} + i_t ∘ tanh(candidate)
            c = forget_gate * c + input_gate * self.tanh(candidate_cell)
            # h_t = o_t ∘ tanh(c_t)
            h = output_gate * self.tanh(c)
            y_list.append(self.output(h))
        # (batch_size, seq_len, output_size), plus the final hidden/cell states
        return torch.stack(y_list, dim=1), (h, c)
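As a quick sanity check of the class above (the sizes here are arbitrary assumptions):

import torch

model = My_LSTM(input_size=8, hidden_size=16, output_size=2)
x = torch.randn(4, 10, 8)      # (batch_size, seq_len, input_size)
y, (h, c) = model(x)
print(y.shape)                 # torch.Size([4, 10, 2]): one output per timestep
print(h.shape, c.shape)        # torch.Size([4, 16]) each: final hidden/cell state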
Implementing LSTM with PyTorch
Parameters: input_size (features per timestep), hidden_size (dimension of h and c), num_layers (number of stacked LSTM layers), batch_first (if True, tensors are laid out as (batch, seq, feature)).
Inputs: the input sequence of shape (batch, seq_len, input_size) when batch_first=True, plus an optional initial state (h_0, c_0).
Outputs: output of shape (batch, seq_len, hidden_size) containing h_t for every timestep, and the final state (h_n, c_n).
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1, batch_first=True).to(device)
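A minimal usage sketch (the sizes and the device value are assumptions for illustration):

import torch
import torch.nn as nn

device = "cpu"
input_size, hidden_size = 8, 16
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
               num_layers=1, batch_first=True).to(device)

x = torch.randn(4, 10, input_size, device=device)  # (batch, seq_len, input_size)
output, (h_n, c_n) = lstm(x)   # omitting (h_0, c_0) defaults them to zeros
print(output.shape)  # torch.Size([4, 10, 16]): h_t for every timestep
print(h_n.shape)     # torch.Size([1, 4, 16]): (num_layers, batch, hidden_size)
print(c_n.shape)     # torch.Size([1, 4, 16])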