Variational Inference with Normalizing Flow

Rezende D., Mohamed S. Variational Inference with Normalizing Flow. ICML, 2015.

VAE的先验分布很重要, 但是后验分布也很重要, 我们常常假设 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx)满足一个高斯分布, 这就大大限制了近似后验分布的逼近的准确性.
这番假设实在是过于强烈了.
本文提出的 normalizing flows的方法可以提高 q ϕ q_{\phi} qϕ的逼近能力.

主要内容

首先, 假设我们得到了 q 0 ( z 0 ∣ x ) q_{0}(z_0|x) q0(z0x)(通过重采样得到 z z z), 此时我们通过一个可逆函数 f f f, 得到
z 1 = f ( z 0 ) , z_1 = f(z_0), z1=f(z0),
z 1 z_1 z1的分布满足:
q ( z 1 ) = q ( z 0 ) ∣ d e t ∇ z f − 1 ∣ = q ( z 0 ) ∣ d e t ∇ f ∣ − 1 . q(z_1) = q(z_0) |\mathrm{det} \nabla_z f^{-1}| = q(z_0) |\mathrm{det} \nabla f|^{-1}. q(z1)=q(z0)detzf1=q(z0)detf1.
以此类推可得:
z K = f K ∘ ⋯ ∘ f 2 ∘ f 1 ( z 0 ) , ln ⁡ q K ( z K ) = ln ⁡ q 0 ( z 0 ) − ∑ k = 1 K ln ⁡ ∣ d e t ∇ z k − 1 f k ∣ . z_K = f_K \circ \cdots \circ f_2 \circ f_1(z_0), \\ \ln q_K(z_K) = \ln q_0(z_0) - \sum_{k=1}^K \ln |\mathrm{det} \nabla_{z_{k-1}} f_k|. zK=fKf2f1(z0),lnqK(zK)=lnq0(z0)k=1Klndetzk1fk.

也就是说, 只要我们能计算出Jacobian行列式, 那么后验分布的近似能力就大大提高了.

此时ELBO的负数形式为:
F ( x ) = E q ϕ ( z ∣ x ) [ ln ⁡ q ϕ ( z ∣ x ) − ln ⁡ p θ ( x , z ) ] = E q 0 ( z 0 ) [ ln ⁡ q K ( z K ) − ln ⁡ p θ ( x , z K ) ] = E q 0 ( z 0 ) [ ln ⁡ q 0 ( z 0 ) ] − E q 0 ( z 0 ) [ ∑ k = 1 K ln ⁡ ∣ d e t ∇ z k − 1 f k ∣ ] + E q 0 ( z 0 ) [ ln ⁡ p θ ( x , z K ) ] . \begin{array}{ll} \mathcal{F}(x) &= \mathbb{E}_{q_{\phi}(z|x)}[\ln q_{\phi}(z|x) - \ln p_{\theta}(x,z)] \\ &= \mathbb{E}_{q_{0}(z_0)}[\ln q_{K}(z_K) - \ln p_{\theta}(x,z_K)] \\ &= \mathbb{E}_{q_0(z_0)}[\ln q_0(z_0)] - \mathbb{E}_{q_0(z_0)}[\sum_{k=1}^K\ln |\mathrm{det} \nabla_{z_{k-1}} f_k|] \\ & + \mathbb{E}_{q_0(z_0)} [\ln p_{\theta}(x,z_K)]. \end{array} F(x)=Eqϕ(zx)[lnqϕ(zx)lnpθ(x,z)]=Eq0(z0)[lnqK(zK)lnpθ(x,zK)]=Eq0(z0)[lnq0(z0)]Eq0(z0)[k=1Klndetzk1fk]+Eq0(z0)[lnpθ(x,zK)].

注: 因为最后一项和 q K q_K qK无关, 可以由采样直接近似.

一些合适的可逆变换

f ( z ) = z + u h ( w T z + b ) , f(z) = z + u h(w^Tz + b), f(z)=z+uh(wTz+b),
其中 h h h是一个非线性的激活函数. 则
ψ ( z ) = h ′ ( w T z + b ) w ∣ d e t ∇ z f ∣ = ∣ 1 + u T ψ ( z ) ∣ . \psi(z) = h'(w^Tz+b)w \\ |\mathrm{det} \nabla_z f| = |1 + u^T \psi(z)|. ψ(z)=h(wTz+b)wdetzf=1+uTψ(z).

f ( z ) = z + β ⋅ h ( α , γ ) ( z − z 0 ) , γ = ∣ z − z 0 ∣ , h ( α , γ ) = 1 / ( α + γ ) . f(z) = z + \beta \cdot h(\alpha, \gamma)(z-z_0), \\ \gamma = |z - z_0|, h(\alpha, \gamma) = 1/ (\alpha + \gamma). f(z)=z+βh(α,γ)(zz0),γ=zz0,h(α,γ)=1/(α+γ).
此时
∣ d e t ∇ z f ∣ = [ 1 + β h ( α , γ ) ] d − 1 [ 1 + β h ( α , γ ) + β h ′ ( α , γ ) γ ] . |\mathrm{det} \nabla_z f| = [1 + \beta h(\alpha, \gamma)]^{d-1}[1+\beta h(\alpha, \gamma) + \beta h'(\alpha, \gamma) \gamma]. detzf=[1+βh(α,γ)]d1[1+βh(α,γ)+βh(α,γ)γ].
其中 d d d z z z的维度.

代码

非官方代码

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值