LSTM 在tensorflow.keras中的实现

从lstm前向传播计算开始吧

输 入 : x t , a t − 1 , c t − 1 权 重 : W u , b u , W f , b f , W o , b o , W c , b c 更 新 门 : Γ u = σ ( W u [ a t − 1 , x t ] + b u ) 遗 忘 门 : Γ f = σ ( W f [ a t − 1 , x t ] + b f ) 输 出 门 : Γ o = σ ( W o [ a t − 1 , x t ] + b o ) c t ~ = t a n h ( W c [ a t − 1 , x t ] + b c ) 细 胞 状 态 : c t = Γ u c t ~ + Γ f c t − 1 隐 藏 状 态 : a t = Γ o t a n h ( c t ) 输 出 : y t ^ = s o f t m a x ( a t ) 输入:x_t,a_{t-1},c_{t-1}\\ 权重:W_u,b_u,W_f,b_f,W_o,b_o,W_c,b_c\\ 更新门:\Gamma_u=\sigma(W_u[a_{t-1},x_t]+b_u)\\ 遗忘门:\Gamma_f=\sigma(W_f[a_{t-1},x_t]+b_f)\\ 输出门:\Gamma_o=\sigma(W_o[a_{t-1},x_t]+b_o)\\ \tilde{c_t}=tanh(W_c[a_{t-1},x_t]+b_c)\\ 细胞状态:c_t=\Gamma_u\tilde{c_t}+\Gamma_fc_{t-1}\\ 隐藏状态:a_t=\Gamma_otanh(c_t)\\ 输出:\hat{y_t}=softmax(a_t)\\ xt,at1,ct1Wu,bu,Wf,bf,Wo,bo,Wc,bcΓu=σ(Wu[at1,xt]+bu)Γf=σ(Wf[at1,xt]+bf)Γo=σ(Wo[at1,xt]+bo)ct~=tanh(Wc[at1,xt]+bc)ct=Γuct~+Γfct1at=Γotanh(ct)yt^=softmax(at)

如何计算参数数量

显而易见,参数数量应该为 ( ( 输 入 特 征 数 + 隐 藏 单 元 数 ) ∗ 隐 藏 单 元 数 + 隐 藏 单 元 数 ) ∗ 4 ((输入特征数+隐藏单元数)*隐藏单元数+隐藏单元数)*4 ((+)+)4

在keras中试一试

测试代码如下

from tensorflow import keras
input=keras.Input([32,64])
lstm=keras.layers.LSTM(128)
lstm(input)
print(*[(i.name,i.shape) for i in lstm.weights],sep='\n')
print(lstm.count_params())

结果为
在这里插入图片描述

看到这个你可能会感到奇怪,这一层不是应该有 W u , b u , W f , b f , W o , b o , W c , b c W_u,b_u,W_f,b_f,W_o,b_o,W_c,b_c Wu,bu,Wf,bf,Wo,bo,Wc,bc 8个权重吗,为什么只剩下3个了。

分析原因

以之前提到的lstm参数数量计算公式可以得到结果为 ( ( 64 + 128 ) ∗ 128 + 128 ) ∗ 4 = 98816 ((64+128)*128+128)*4=98816 ((64+128)128+128)4=98816,可以看出weights确实保存了前向传播所需要的参数,只不过使用了另外一种格式。
观察到weights的形状中有一个512,可以推得这个值是由 128 ∗ 4 128*4 1284 而来。
由此猜测:

kernel值为4个W对应x的值
recurrent_kernel为4个W对应a的值
bias则为4个b叠起来

查看源码

keras.layers.LSTM继承于keras.layers.LSTMCell,可以在这里找到LSTM的关键计算和所有权重的声明。

权重初始化

	self.kernel = self.add_weight(
        shape=(input_dim, self.units * 4),
        name='kernel',
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        caching_device=default_caching_device)
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units * 4),
        name='recurrent_kernel',
        initializer=self.recurrent_initializer,
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint,
        caching_device=default_caching_device)

    if self.use_bias:
      if self.unit_forget_bias:#将遗忘门偏置量初始化为1,其他则用0来初始化
        def bias_initializer(_, *args, **kwargs):
          return K.concatenate([
              self.bias_initializer((self.units,), *args, **kwargs),
              initializers.Ones()((self.units,), *args, **kwargs),
              self.bias_initializer((self.units * 2,), *args, **kwargs),
          ])
      else:
        bias_initializer = self.bias_initializer
      self.bias = self.add_weight(
          shape=(self.units * 4,),
          name='bias',
          initializer=bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          caching_device=default_caching_device)
    else:
      self.bias = None

在这里可以证明之前的猜测,而权重的具体排列顺序则需要到call函数中寻找。可以找到lstm的两种实现:

实现1

	  k_i, k_f, k_c, k_o = array_ops.split(
          self.kernel, num_or_size_splits=4, axis=1)# 分别取出各个操作的权重
      x_i = K.dot(inputs_i, k_i)
      x_f = K.dot(inputs_f, k_f)
      x_c = K.dot(inputs_c, k_c)
      x_o = K.dot(inputs_o, k_o)
      b_i, b_f, b_c, b_o = array_ops.split(
          self.bias, num_or_size_splits=4, axis=0)
      x_i = K.bias_add(x_i, b_i)
      x_f = K.bias_add(x_f, b_f)
      x_c = K.bias_add(x_c, b_c)
      x_o = K.bias_add(x_o, b_o)
	  #这里省略了dropout的处理		
      h_tm1_i = h_tm1 
      h_tm1_f = h_tm1
      h_tm1_c = h_tm1
      h_tm1_o = h_tm1
      x = (x_i, x_f, x_c, x_o)#各个门添加权重和偏置后的x
      h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
      c, o = self._compute_carry_and_output(x, h_tm1, c_tm1) # 处理门,参数为各个权重和偏置处理后的x及上一步隐藏状态,细胞状态
      h = o * self.activation(c)#新的隐藏状态

  def _compute_carry_and_output(self, x, h_tm1, c_tm1):
    """Computes carry and output using split kernels."""
    x_i, x_f, x_c, x_o = x
    h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 #跟dropout有关,不用管,当做一样的就行
    i = self.recurrent_activation(
        x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]))#更新门值
    f = self.recurrent_activation(x_f + K.dot(
        h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]))#遗忘门值
    c = f * c_tm1 + i * self.activation(x_c + K.dot(
        h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))#新的细胞状态
    o = self.recurrent_activation(
        x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]))#输出门值
    return c, o

这种实现方式公式如下

输 入 : x t , a t − 1 , c t − 1 权 重 : W u x , W u a , b u , W f x , W f a , b f , W o x , W o a , b o , W c x , W c a , b c 处 理 x : 更 新 门 : u x t = W u x x t + b u 遗 忘 门 : f x t = W f x x t + b f 输 出 门 : o x t = W o x x t + b o 细 胞 状 态 中 间 计 算 值 : c x t = W c x x t + b c 处 理 隐 藏 状 态 , 计 算 门 值 和 新 状 态 : 更 新 门 : Γ u = σ ( W u a t a t − 1 + u x t ) 遗 忘 门 : Γ u = σ ( W f a a t − 1 + f x t ) 输 出 门 : Γ u = σ ( W o a a t − 1 + o x t ) 细 胞 状 态 : c t = Γ u t a n h ( W c a a t − 1 + c x ) + Γ f c t − 1 隐 藏 状 态 : a t = Γ o t a n h ( c t ) 输入:x_t,a_{t-1},c_{t-1}\\ 权重:W_{ux},W_{ua},b_u,W_{fx},W_{fa},b_f,W_{ox},W_{oa},b_o,W_{cx},W_{ca},b_c\\ 处理x:\\ 更新门:ux_t=W_{ux} x_t+b_u\\ 遗忘门:fx_t=W_{fx} x_t+b_f\\ 输出门:ox_t=W_{ox}x_t+b_o\\ 细胞状态中间计算值:cx_t=W_{cx}x_t+b_c\\ 处理隐藏状态,计算门值和新状态:\\ 更新门:\Gamma_u=\sigma(W_{ua}t a_{t-1}+ux_t)\\ 遗忘门:\Gamma_u=\sigma(W_{fa} a_{t-1}+fx_t)\\ 输出门:\Gamma_u=\sigma(W_{oa} a_{t-1}+ox_t)\\ 细胞状态:c_t=\Gamma_utanh(W_{ca} a_{t-1}+cx)+\Gamma_fc_{t-1}\\ 隐藏状态:a_t=\Gamma_otanh(c_t)\\ xt,at1,ct1Wux,Wua,bu,Wfx,Wfa,bf,Wox,Woa,bo,Wcx,Wca,bcxuxt=Wuxxt+bufxt=Wfxxt+bfoxt=Woxxt+bocxt=Wcxxt+bcΓu=σ(Wuatat1+uxt)Γu=σ(Wfaat1+fxt)Γu=σ(Woaat1+oxt)ct=Γutanh(Wcaat1+cx)+Γfct1at=Γotanh(ct)

基本上就是拆开了隐藏状态和x的处理,需要注意的是偏置在处理x时添加,和在处理隐藏状态是添加是等价的。

实现2

	z = K.dot(inputs, self.kernel)
	z += K.dot(h_tm1, self.recurrent_kernel)
	z = K.bias_add(z, self.bias)
	#同时对多个门添加权重和偏置
	z = array_ops.split(z, num_or_size_splits=4, axis=1)#将计算结果分开,还原成多个门值
	c, o = self._compute_carry_and_output_fused(z, c_tm1)#用门处理,参数是权重和偏置处理后的x,a和上一步细胞状态
	
	h = o * self.activation(c)

  def _compute_carry_and_output_fused(self, z, c_tm1):
    """Computes carry and output using fused kernels."""
    z0, z1, z2, z3 = z
    i = self.recurrent_activation(z0)#更新门
    f = self.recurrent_activation(z1)#遗忘门
    c = f * c_tm1 + i * self.activation(z2)#新的细胞状态
    o = self.recurrent_activation(z3)#输出门
    return c, o

这种实现方式公式如下

输 入 : x t , a t − 1 , c t − 1 批 处 理 数 为 n , 输 入 特 征 数 为 c , 隐 藏 单 元 数 为 m 权 重 : W u x , W u a , b u , W f x , W f a , b f , W o x , W o a , b o , W c x , W c a , b c 对 x 添 加 权 重 : z x = x t ⏞ c } n [ ⋮ ⋮ ⋮ ⋮ w u x w f x w c x w o x ⋮ ⋮ ⋮ ⋮ ] ⏞ 4 ∗ m } c = > n × m ∗ 4 对 a 添 加 权 重 : z a = a t − 1 ⏞ m } n [ ⋮ ⋮ ⋮ ⋮ w u a w f a w c a w o a ⋮ ⋮ ⋮ ⋮ ] ⏞ 4 ∗ m } m = > n × m ∗ 4 相 加 并 添 加 偏 置 : z = z a + z x + [ ⋮ ⋮ ⋮ ⋮ b u b f b c b o ⋮ ⋮ ⋮ ⋮ ] = [ ⋮ ⋮ ⋮ ⋮ z 0 z 1 z 2 z 3 ⋮ ⋮ ⋮ ⋮ ] = > n × m ∗ 4 更 新 门 : Γ u = σ ( z 0 ) : n × m 遗 忘 门 : Γ u = σ ( z 1 ) : n × m 输 出 门 : Γ u = σ ( z 2 ) : n × m 细 胞 状 态 : c t = Γ u t a n h ( z 3 ) + Γ f c t − 1 : n × m ∗ 4 隐 藏 状 态 : a t = Γ o t a n h ( c t ) : n × m ∗ 4 输入:x_t,a_{t-1},c_{t-1}\\ 批处理数为n,输入特征数为c,隐藏单元数为m\\ 权重:W_{ux},W_{ua},b_u,W_{fx},W_{fa},b_f,W_{ox},W_{oa},b_o,W_{cx},W_{ca},b_c\\ 对x添加权重:zx=\overbrace{x_t}^c\}n \left. \overbrace{ \begin{bmatrix} \vdots & \vdots & \vdots& \vdots \\ w_{ux} & w_{fx} &w_{cx}& w_{ox} \\ \vdots & \vdots & \vdots & \vdots & \\ \end{bmatrix}}^{4*m}\right \} c =>n\times m*4 \\ 对a添加权重:za=\overbrace{a_{t-1} }^m\}n \left. \overbrace{ \begin{bmatrix} \vdots & \vdots & \vdots& \vdots \\ w_{ua} & w_{fa} &w_{ca}& w_{oa} \\ \vdots & \vdots & \vdots & \vdots & \\ \end{bmatrix}}^{4*m}\right \}m =>n\times m*4\\ 相加并添加偏置:z=za+zx+\begin{bmatrix} \vdots & \vdots & \vdots& \vdots \\ b_{u} & b_{f} &b_{c}& b_{o} \\ \vdots & \vdots & \vdots & \vdots & \\ \end{bmatrix}=\begin{bmatrix} \vdots & \vdots & \vdots& \vdots \\ z_{0} & z_{1} &z_{2}& z_{3} \\ \vdots & \vdots & \vdots & \vdots & \\ \end{bmatrix} =>n\times m*4\\ 更新门:\Gamma_u=\sigma(z_0) :n\times m\\ 遗忘门:\Gamma_u=\sigma(z_1) :n\times m\\ 输出门:\Gamma_u=\sigma(z_2) :n\times m\\ 细胞状态:c_t=\Gamma_utanh(z_3)+\Gamma_fc_{t-1} :n\times m*4\\ 隐藏状态:a_t=\Gamma_otanh(c_t) :n\times m*4\\ xt,at1,ct1ncmWux,Wua,bu,Wfx,Wfa,bf,Wox,Woa,bo,Wcx,Wca,bcxzx=xt c}nwuxwfxwcxwox 4mc=>n×m4aza=at1 m}nwuawfawcawoa 4mm=>n×m4z=za+zx+bubfbcbo=z0z1z2z3=>n×m4Γu=σ(z0):n×mΓu=σ(z1):n×mΓu=σ(z2):n×mct=Γutanh(z3)+Γfct1:n×m4at=Γotanh(ct):n×m4

需要注意,代码里面的K.dot是矩阵乘法

总结

lstm对门和 c t ~ \tilde{c_t} ct~的计算非常相似,第二种实现就是利用了这个特性使得计算更加密集,能得到一定的加速,这是一种非常优雅的实现方式。明显,weights的格式方便了这个计算过程。

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值