深度强化元学习教程---优化器元学习2/2

优化器元网络推导

在梯度下降算法中,我们通过下面的公式来调整参数:
θ t = θ t − 1 − α t ∇ θ t − 1 L t \boldsymbol{\theta}_{t} = \boldsymbol{\theta}_{t-1} - \alpha_{t} \nabla_{\boldsymbol{\theta}_{t-1}} \mathcal{L}_{t} θt=θt1αtθt1Lt
根据上一节长短时记忆网络的讨论,我们更新Cell记忆时的公式为:
C t = f t ⊗ C t − 1 + i t ⊗ C ~ t \boldsymbol{C}_{t}=\boldsymbol{f}_{t} \otimes \boldsymbol{C}_{t-1} + \boldsymbol{i}_{t} \otimes \tilde{\boldsymbol{C}}_{t} Ct=ftCt1+itC~t
我们可以做如下假设:
f t = 1 C t − 1 = θ t − 1 i t = α t C t ~ = ∇ θ t − 1 L t \boldsymbol{f}_{t}=1 \\ \boldsymbol{C}_{t-1} = \boldsymbol{\theta}_{t-1} \\ \boldsymbol{i}_{t} = \alpha_{t} \\ \tilde{\boldsymbol{C}_{t}} = \nabla_{\boldsymbol{\theta}_{t-1}} \mathcal{L}_{t} ft=1Ct1=θt1it=αtCt~=θt1Lt
当我们做上述假设后,梯度下降算法就可以视为长短时记忆网络(LSTM)的Cell更新过程。如果我们把原来做图像分类的网络作为基础网络,用长短时记忆网络(LSTM)做优化器的元网络,LSTM网络中的Cell的状态值即为基础网络的参数,这样就构成了优化器元网络。
我们首先列出长短时记忆网络的公式:
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) C ~ t = t a n h ( W C ⋅ [ h t − 1 , x t ] + b C ) C t = f t ⊗ C t − 1 + i t ⊗ C ~ t o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) h t = o t ⊗ t a n h ( C t ) \boldsymbol{f}_{t}=\sigma( W_{f} \cdot [\boldsymbol{h}_{t-1}, \boldsymbol{x}_{t}] + \boldsymbol{b}_{f} ) \\ \boldsymbol{i}_{t}=\sigma( W_{i} \cdot [\boldsymbol{h}_{t-1}, \boldsymbol{x}_{t}] + \boldsymbol{b}_{i} ) \\ \tilde{\boldsymbol{C}}_{t}=tanh( W_{C} \cdot [\boldsymbol{h}_{t-1}, \boldsymbol{x}_{t}] + \boldsymbol{b}_{C} ) \\ \boldsymbol{C}_{t}=\boldsymbol{f}_{t} \otimes \boldsymbol{C}_{t-1} + \boldsymbol{i}_{t} \otimes \tilde{\boldsymbol{C}}_{t} \\ \boldsymbol{o}_{t}=\sigma( W_{o} \cdot [\boldsymbol{h}_{t-1}, \boldsymbol{x}_{t}] + \boldsymbol{b}_{o} ) \\ \boldsymbol{h}_{t}=\boldsymbol{o}_{t} \otimes tanh(\boldsymbol{C}_{t}) ft=σ(Wf[ht1,xt]+bf)it=σ(Wi[ht1,xt]+bi)C~t=tanh(WC[ht1,xt]+bC)Ct=ftCt1+itC~tot=σ(Wo[ht1,xt]+bo)ht=ottanh(Ct)
我们首先来看遗忘门的应用,我们在训练深度学习网络中,最难处理的一种情况是我们进入了一个平坦的区域,这时梯度基本为零,对权值的调整将非常非常小,无法进行学习。如果我们利用元网络中的遗忘门,我们可以通过减小参数值,忘记一些之前的记忆来实现从这个平坦的区域内跳出来。此时遗忘门的输入值为:前一时刻的参数值、当前时刻基础网络的代价函数值、当前时刻基础网络的代价函数对网络参数的微分值、前一时刻遗忘门的输出,如下所示:
f t = σ ( W f ⋅ [ θ t − 1 , L t , ∇ θ t − 1 , f t − 1 ] + b f ) \boldsymbol{f}_{t}=\sigma( W_{f} \cdot [\boldsymbol{\theta}_{t-1}, \mathcal{L}_{t},\nabla_{\boldsymbol{\theta}_{t-1}}, \boldsymbol{f}_{t-1}] + \boldsymbol{b}_{f} ) ft=σ(Wf[θt1,Lt,θt1,ft1]+bf)
我们用输入门来控制需要更新哪些参数:
i t = σ ( W i ⋅ [ θ t − 1 , L t , ∇ θ t − 1 , i t − 1 ] + b i ) \boldsymbol{i}_{t}=\sigma( W_{i} \cdot [\boldsymbol{\theta}_{t-1}, \mathcal{L}_{t},\nabla_{\boldsymbol{\theta}_{t-1}}, \boldsymbol{i}_{t-1}] + \boldsymbol{b}_{i} ) it=σ(Wi[θt1,Lt,θt1,it1]+bi)

优化器学习算法

优化器元学习
利用随机值初始化元网络参数 ϕ 0 \boldsymbol{\phi}_{0} ϕ0
For d=1…N iterations
\quad 从数据集 D D D中随机采样 D t r a i n D^{train} Dtrain D t e s t D^{test} Dtest
\quad 将元网络中Cell初始状态 C 0 \boldsymbol{C}_{0} C0赋给基础网络参数 θ 0 \boldsymbol{\theta}_{0} θ0
\quad For t = 1... T t=1...T t=1...T iterations
\quad \quad D t r a i n D^{train} Dtrain中随机抽样一个批次 X t , Y t X_{t}, Y_{t} Xt,Yt
\quad \quad 在基础网络上计算代价函数值: L t ( Y t ∣ X t ; θ t ) \mathcal{L}_{t}(Y_{t} \vert X_{t}; \boldsymbol{\theta}_{t}) Lt(YtXt;θt)
\quad \quad 遗忘门: f t = σ ( W f ⋅ [ θ t − 1 , L t , ∇ θ t − 1 , f t − 1 ] + b f ) \boldsymbol{f}_{t}=\sigma( W_{f} \cdot [\boldsymbol{\theta}_{t-1}, \mathcal{L}_{t},\nabla_{\boldsymbol{\theta}_{t-1}}, \boldsymbol{f}_{t-1}] + \boldsymbol{b}_{f} ) ft=σ(Wf[θt1,Lt,θt1,ft1]+bf)
\quad \quad 输入门: i t = σ ( W i ⋅ [ θ t − 1 , L t , ∇ θ t − 1 , i t − 1 ] + b i ) \boldsymbol{i}_{t}=\sigma( W_{i} \cdot [\boldsymbol{\theta}_{t-1}, \mathcal{L}_{t},\nabla_{\boldsymbol{\theta}_{t-1}}, \boldsymbol{i}_{t-1}] + \boldsymbol{b}_{i} ) it=σ(Wi[θt1,Lt,θt1,it1]+bi)
\quad \quad 输入信号预处理: C ~ t = t a n h ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{\boldsymbol{C}}_{t}=tanh( W_{C} \cdot [\boldsymbol{h}_{t-1}, \boldsymbol{x}_{t}] + \boldsymbol{b}_{C} ) C~t=tanh(WC[ht1,xt]+bC)
\quad \quad 更新Cell状态: C t = f t ⊗ C t − 1 + i t ⊗ C ~ t \boldsymbol{C}_{t}=\boldsymbol{f}_{t} \otimes \boldsymbol{C}_{t-1} + \boldsymbol{i}_{t} \otimes \tilde{\boldsymbol{C}}_{t} Ct=ftCt1+itC~t
\quad \quad 输出门: o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) \boldsymbol{o}_{t}=\sigma( W_{o} \cdot [\boldsymbol{h}_{t-1}, \boldsymbol{x}_{t}] + \boldsymbol{b}_{o} ) ot=σ(Wo[ht1,xt]+bo)
\quad \quad 隐藏层状态: h t = o t ⊗ t a n h ( C t ) \boldsymbol{h}_{t}=\boldsymbol{o}_{t} \otimes tanh(\boldsymbol{C}_{t}) ht=ottanh(Ct)
\quad \quad 将元网络的Cell状态 C t \boldsymbol{C}_{t} Ct赋引基础网络参数 θ ∣ t \boldsymbol{\theta|_{t}} θt
\quad EndFor
\quad 从测试集 D t e s t D^{test} Dtest中抽样出一个样本 X , Y X, Y X,Y
\quad 计算基础网络代价函数: L t e s t = L θ t ( Y ∣ X ; θ t ) \mathcal{L}^{test}=\mathcal{L}_{\boldsymbol{\theta}_{t}}(Y \vert X; \boldsymbol{\theta}_{t}) Ltest=Lθt(YX;θt)
\quad 利用 ∇ L t e s t θ t − 1 \nabla{\mathcal{L}^{test}}_{\boldsymbol{\theta}_{t-1}} Ltestθt1更新元网络参数 ϕ d \phi_{d} ϕd
EndFor
在本章中我们讲述了元学习的基本概念,同时以优化器元学习网络为例,详细讲解了元学习算法的基本数学原理。在下一章中,我们将讨论最先出现同时也是使用最广泛的一种元学习网络Siamese网络,并通过TensorFlow2.0来实现对MNIST手写数字数据集的处理。

  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值