LSTM 在tensorflow.keras中的实现
从lstm前向传播计算开始吧
输 入 : x t , a t − 1 , c t − 1 权 重 : W u , b u , W f , b f , W o , b o , W c , b c 更 新 门 : Γ u = σ ( W u [ a t − 1 , x t ] + b u ) 遗 忘 门 : Γ f = σ ( W f [ a t − 1 , x t ] + b f ) 输 出 门 : Γ o = σ ( W o [ a t − 1 , x t ] + b o ) c t ~ = t a n h ( W c [ a t − 1 , x t ] + b c ) 细 胞 状 态 : c t = Γ u c t ~ + Γ f c t − 1 隐 藏 状 态 : a t = Γ o t a n h ( c t ) 输 出 : y t ^ = s o f t m a x ( a t ) 输入:x_t,a_{t-1},c_{t-1}\\ 权重:W_u,b_u,W_f,b_f,W_o,b_o,W_c,b_c\\ 更新门:\Gamma_u=\sigma(W_u[a_{t-1},x_t]+b_u)\\ 遗忘门:\Gamma_f=\sigma(W_f[a_{t-1},x_t]+b_f)\\ 输出门:\Gamma_o=\sigma(W_o[a_{t-1},x_t]+b_o)\\ \tilde{c_t}=tanh(W_c[a_{t-1},x_t]+b_c)\\ 细胞状态:c_t=\Gamma_u\tilde{c_t}+\Gamma_fc_{t-1}\\ 隐藏状态:a_t=\Gamma_otanh(c_t)\\ 输出:\hat{y_t}=softmax(a_t)\\ 输入:xt,at−1,ct−1权重:Wu,bu,Wf,bf,Wo,bo,Wc,bc更新门:Γu=σ(Wu[at−1,xt]+bu)遗忘门:Γf=σ(Wf[at−1,xt]+bf)输出门:Γo=σ(Wo[at−1,xt]+bo)ct~=tanh(Wc[at−1,xt]+bc)细胞状态:ct=Γuct~+Γfct−1隐藏状态:at=Γotanh(ct)输出:yt^=softmax(at)
如何计算参数数量
显而易见,参数数量应该为 ( ( 输 入 特 征 数 + 隐 藏 单 元 数 ) ∗ 隐 藏 单 元 数 + 隐 藏 单 元 数 ) ∗ 4 ((输入特征数+隐藏单元数)*隐藏单元数+隐藏单元数)*4 ((输入特征数+隐藏单元数)∗隐藏单元数+隐藏单元数)∗4
在keras中试一试
测试代码如下
from tensorflow import keras
input=keras.Input([32,64])
lstm=keras.layers.LSTM(128)
lstm(input)
print(*[(i.name,i.shape) for i in lstm.weights],sep='\n')
print(lstm.count_params())
结果为
看到这个你可能会感到奇怪,这一层不是应该有 W u , b u , W f , b f , W o , b o , W c , b c W_u,b_u,W_f,b_f,W_o,b_o,W_c,b_c Wu,bu,Wf,bf,Wo,bo,Wc,bc 8个权重吗,为什么只剩下3个了。
分析原因
以之前提到的lstm参数数量计算公式可以得到结果为 ( ( 64 + 128 ) ∗ 128 + 128 ) ∗ 4 = 98816 ((64+128)*128+128)*4=98816 ((64+128)∗128+128)∗4=98816,可以看出weights确实保存了前向传播所需要的参数,只不过使用了另外一种格式。
观察到weights的形状中有一个512,可以推得这个值是由 128 ∗ 4 128*4 128∗4 而来。
由此猜测:kernel值为4个W对应x的值
recurrent_kernel为4个W对应a的值
bias则为4个b叠起来
查看源码
keras.layers.LSTM继承于keras.layers.LSTMCell,可以在这里找到LSTM的关键计算和所有权重的声明。
权重初始化
self.kernel = self.add_weight(
shape=(input_dim, self.units * 4),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
caching_device=default_caching_device)
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units * 4),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint,
caching_device=default_caching_device)
if self.use_bias:
if self.unit_forget_bias:#将遗忘门偏置量初始化为1,其他则用0来初始化
def bias_initializer(_, *args, **kwargs):
return K.concatenate([
self.bias_initializer((self.units,), *args, **kwargs),
initializers.Ones()((self.units,), *args, **kwargs),
self.bias_initializer((self.units * 2,), *args, **kwargs),
])
else:
bias_initializer = self.bias_initializer
self.bias = self.add_weight(
shape=(self.units * 4,),
name='bias',
initializer=bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
caching_device=default_caching_device)
else:
self.bias = None
在这里可以证明之前的猜测,而权重的具体排列顺序则需要到call函数中寻找。可以找到lstm的两种实现:
实现1
k_i, k_f, k_c, k_o = array_ops.split(
self.kernel, num_or_size_splits=4, axis=1)# 分别取出各个操作的权重
x_i = K.dot(inputs_i, k_i)
x_f = K.dot(inputs_f, k_f)
x_c = K.dot(inputs_c, k_c)
x_o = K.dot(inputs_o, k_o)
b_i, b_f, b_c, b_o = array_ops.split(
self.bias, num_or_size_splits=4, axis=0)
x_i = K.bias_add(x_i, b_i)
x_f = K.bias_add(x_f, b_f)
x_c = K.bias_add(x_c, b_c)
x_o = K.bias_add(x_o, b_o)
#这里省略了dropout的处理
h_tm1_i = h_tm1
h_tm1_f = h_tm1
h_tm1_c = h_tm1
h_tm1_o = h_tm1
x = (x_i, x_f, x_c, x_o)#各个门添加权重和偏置后的x
h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
c, o = self._compute_carry_and_output(x, h_tm1, c_tm1) # 处理门,参数为各个权重和偏置处理后的x及上一步隐藏状态,细胞状态
h = o * self.activation(c)#新的隐藏状态
def _compute_carry_and_output(self, x, h_tm1, c_tm1):
"""Computes carry and output using split kernels."""
x_i, x_f, x_c, x_o = x
h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 #跟dropout有关,不用管,当做一样的就行
i = self.recurrent_activation(
x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]))#更新门值
f = self.recurrent_activation(x_f + K.dot(
h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]))#遗忘门值
c = f * c_tm1 + i * self.activation(x_c + K.dot(
h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))#新的细胞状态
o = self.recurrent_activation(
x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]))#输出门值
return c, o
这种实现方式公式如下
输 入 : x t , a t − 1 , c t − 1 权 重 : W u x , W u a , b u , W f x , W f a , b f , W o x , W o a , b o , W c x , W c a , b c 处 理 x : 更 新 门 : u x t = W u x x t + b u 遗 忘 门 : f x t = W f x x t + b f 输 出 门 : o x t = W o x x t + b o 细 胞 状 态 中 间 计 算 值 : c x t = W c x x t + b c 处 理 隐 藏 状 态 , 计 算 门 值 和 新 状 态 : 更 新 门 : Γ u = σ ( W u a t a t − 1 + u x t ) 遗 忘 门 : Γ u = σ ( W f a a t − 1 + f x t ) 输 出 门 : Γ u = σ ( W o a a t − 1 + o x t ) 细 胞 状 态 : c t = Γ u t a n h ( W c a a t − 1 + c x ) + Γ f c t − 1 隐 藏 状 态 : a t = Γ o t a n h ( c t ) 输入:x_t,a_{t-1},c_{t-1}\\ 权重:W_{ux},W_{ua},b_u,W_{fx},W_{fa},b_f,W_{ox},W_{oa},b_o,W_{cx},W_{ca},b_c\\ 处理x:\\ 更新门:ux_t=W_{ux} x_t+b_u\\ 遗忘门:fx_t=W_{fx} x_t+b_f\\ 输出门:ox_t=W_{ox}x_t+b_o\\ 细胞状态中间计算值:cx_t=W_{cx}x_t+b_c\\ 处理隐藏状态,计算门值和新状态:\\ 更新门:\Gamma_u=\sigma(W_{ua}t a_{t-1}+ux_t)\\ 遗忘门:\Gamma_u=\sigma(W_{fa} a_{t-1}+fx_t)\\ 输出门:\Gamma_u=\sigma(W_{oa} a_{t-1}+ox_t)\\ 细胞状态:c_t=\Gamma_utanh(W_{ca} a_{t-1}+cx)+\Gamma_fc_{t-1}\\ 隐藏状态:a_t=\Gamma_otanh(c_t)\\ 输入:xt,at−1,ct−1权重:Wux,Wua,bu,Wfx,Wfa,bf,Wox,Woa,bo,Wcx,Wca,bc处理x:更新门:uxt=Wuxxt+bu遗忘门:fxt=Wfxxt+bf输出门:oxt=Woxxt+bo细胞状态中间计算值:cxt=Wcxxt+bc处理隐藏状态,计算门值和新状态:更新门:Γu=σ(Wuatat−1+uxt)遗忘门:Γu=σ(Wfaat−1+fxt)输出门:Γu=σ(Woaat−1+oxt)细胞状态:ct=Γutanh(Wcaat−1+cx)+Γfct−1隐藏状态:at=Γotanh(ct)
基本上就是拆开了隐藏状态和x的处理,需要注意的是偏置在处理x时添加,和在处理隐藏状态是添加是等价的。
实现2
z = K.dot(inputs, self.kernel)
z += K.dot(h_tm1, self.recurrent_kernel)
z = K.bias_add(z, self.bias)
#同时对多个门添加权重和偏置
z = array_ops.split(z, num_or_size_splits=4, axis=1)#将计算结果分开,还原成多个门值
c, o = self._compute_carry_and_output_fused(z, c_tm1)#用门处理,参数是权重和偏置处理后的x,a和上一步细胞状态
h = o * self.activation(c)
def _compute_carry_and_output_fused(self, z, c_tm1):
"""Computes carry and output using fused kernels."""
z0, z1, z2, z3 = z
i = self.recurrent_activation(z0)#更新门
f = self.recurrent_activation(z1)#遗忘门
c = f * c_tm1 + i * self.activation(z2)#新的细胞状态
o = self.recurrent_activation(z3)#输出门
return c, o
这种实现方式公式如下
输 入 : x t , a t − 1 , c t − 1 批 处 理 数 为 n , 输 入 特 征 数 为 c , 隐 藏 单 元 数 为 m 权 重 : W u x , W u a , b u , W f x , W f a , b f , W o x , W o a , b o , W c x , W c a , b c 对 x 添 加 权 重 : z x = x t ⏞ c } n [ ⋮ ⋮ ⋮ ⋮ w u x w f x w c x w o x ⋮ ⋮ ⋮ ⋮ ] ⏞ 4 ∗ m } c = > n × m ∗ 4 对 a 添 加 权 重 : z a = a t − 1 ⏞ m } n [ ⋮ ⋮ ⋮ ⋮ w u a w f a w c a w o a ⋮ ⋮ ⋮ ⋮ ] ⏞ 4 ∗ m } m = > n × m ∗ 4 相 加 并 添 加 偏 置 : z = z a + z x + [ ⋮ ⋮ ⋮ ⋮ b u b f b c b o ⋮ ⋮ ⋮ ⋮ ] = [ ⋮ ⋮ ⋮ ⋮ z 0 z 1 z 2 z 3 ⋮ ⋮ ⋮ ⋮ ] = > n × m ∗ 4 更 新 门 : Γ u = σ ( z 0 ) : n × m 遗 忘 门 : Γ u = σ ( z 1 ) : n × m 输 出 门 : Γ u = σ ( z 2 ) : n × m 细 胞 状 态 : c t = Γ u t a n h ( z 3 ) + Γ f c t − 1 : n × m ∗ 4 隐 藏 状 态 : a t = Γ o t a n h ( c t ) : n × m ∗ 4 输入:x_t,a_{t-1},c_{t-1}\\ 批处理数为n,输入特征数为c,隐藏单元数为m\\ 权重:W_{ux},W_{ua},b_u,W_{fx},W_{fa},b_f,W_{ox},W_{oa},b_o,W_{cx},W_{ca},b_c\\ 对x添加权重:zx=\overbrace{x_t}^c\}n \left. \overbrace{ \begin{bmatrix} \vdots & \vdots & \vdots& \vdots \\ w_{ux} & w_{fx} &w_{cx}& w_{ox} \\ \vdots & \vdots & \vdots & \vdots & \\ \end{bmatrix}}^{4*m}\right \} c =>n\times m*4 \\ 对a添加权重:za=\overbrace{a_{t-1} }^m\}n \left. \overbrace{ \begin{bmatrix} \vdots & \vdots & \vdots& \vdots \\ w_{ua} & w_{fa} &w_{ca}& w_{oa} \\ \vdots & \vdots & \vdots & \vdots & \\ \end{bmatrix}}^{4*m}\right \}m =>n\times m*4\\ 相加并添加偏置:z=za+zx+\begin{bmatrix} \vdots & \vdots & \vdots& \vdots \\ b_{u} & b_{f} &b_{c}& b_{o} \\ \vdots & \vdots & \vdots & \vdots & \\ \end{bmatrix}=\begin{bmatrix} \vdots & \vdots & \vdots& \vdots \\ z_{0} & z_{1} &z_{2}& z_{3} \\ \vdots & \vdots & \vdots & \vdots & \\ \end{bmatrix} =>n\times m*4\\ 更新门:\Gamma_u=\sigma(z_0) :n\times m\\ 遗忘门:\Gamma_u=\sigma(z_1) :n\times m\\ 输出门:\Gamma_u=\sigma(z_2) :n\times m\\ 细胞状态:c_t=\Gamma_utanh(z_3)+\Gamma_fc_{t-1} :n\times m*4\\ 隐藏状态:a_t=\Gamma_otanh(c_t) :n\times m*4\\ 输入:xt,at−1,ct−1批处理数为n,输入特征数为c,隐藏单元数为m权重:Wux,Wua,bu,Wfx,Wfa,bf,Wox,Woa,bo,Wcx,Wca,bc对x添加权重:zx=xt c}n⎣⎢⎢⎡⋮wux⋮⋮wfx⋮⋮wcx⋮⋮wox⋮⎦⎥⎥⎤ 4∗m⎭⎪⎪⎪⎪⎪⎪⎬⎪⎪⎪⎪⎪⎪⎫c=>n×m∗4对a添加权重:za=at−1 m}n⎣⎢⎢⎡⋮wua⋮⋮wfa⋮⋮wca⋮⋮woa⋮⎦⎥⎥⎤ 4∗m⎭⎪⎪⎪⎪⎪⎪⎬⎪⎪⎪⎪⎪⎪⎫m=>n×m∗4相加并添加偏置:z=za+zx+⎣⎢⎢⎡⋮bu⋮⋮bf⋮⋮bc⋮⋮bo⋮⎦⎥⎥⎤=⎣⎢⎢⎡⋮z0⋮⋮z1⋮⋮z2⋮⋮z3⋮⎦⎥⎥⎤=>n×m∗4更新门:Γu=σ(z0):n×m遗忘门:Γu=σ(z1):n×m输出门:Γu=σ(z2):n×m细胞状态:ct=Γutanh(z3)+Γfct−1:n×m∗4隐藏状态:at=Γotanh(ct):n×m∗4
需要注意,代码里面的K.dot是矩阵乘法
总结
lstm对门和 c t ~ \tilde{c_t} ct~的计算非常相似,第二种实现就是利用了这个特性使得计算更加密集,能得到一定的加速,这是一种非常优雅的实现方式。明显,weights的格式方便了这个计算过程。