Auto-Encoding Variational Bayes

Kingma D P, Welling M. Auto-Encoding Variational Bayes[J]. arXiv: Machine Learning, 2013.

主要内容

自编码, 通过引入Encoder和Decoder来估计联合分布 p ( x , z ) p(x,z) p(x,z), 其中 z z z表示隐变量(我们也可以让 z z z为样本标签, 使得Encoder成为一个判别器).

在Decoder中我们建立联合分布 p θ ( x , z ) p_{\theta}(x,z) pθ(x,z)以估计 p ( x , z ) p(x,z) p(x,z), 在Encoder中建立一个后验分布 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx)去估计 p θ ( z ∣ x ) p_{\theta}(z|x) pθ(zx), 然后极大似然:
log ⁡ p θ ( x ) = log ⁡ p θ ( x , z ) p θ ( z ∣ x ) = log ⁡ p θ ( x , z ) q ϕ ( z ∣ x ) q ϕ ( z ∣ x ) p θ ( z ∣ x ) = log ⁡ p θ ( x , z ) q ϕ ( z ∣ x ) + log ⁡ q ϕ ( z ∣ x ) p θ ( z ∣ x ) , \begin{array}{ll} \log p_{\theta}(x) &= \log \frac{p_{\theta}(x,z)}{p_{\theta}(z|x)} \\ & = \log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)} \\ & = \log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} + \log \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)} \\ \end{array}, logpθ(x)=logpθ(zx)pθ(x,z)=logqϕ(zx)pθ(x,z)pθ(zx)qϕ(zx)=logqϕ(zx)pθ(x,z)+logpθ(zx)qϕ(zx),
上式俩边关于 z z z在分布 q ϕ ( z ) q_{\phi}(z) qϕ(z)下求期望可得:
log ⁡ p θ ( x ) = E q ϕ ( z ∣ x ) ( log ⁡ p θ ( x , z ) q ϕ ( z ∣ x ) + log ⁡ q ϕ ( z ∣ x ) p θ ( z ∣ x ) ) = E q ϕ ( z ∣ x ) ( log ⁡ p θ ( x , z ) q ϕ ( z ∣ x ) ) + D K L ( q ϕ ( z ∣ x ) ∥ p θ ( z ∣ x ) ) ≥ E q ϕ ( z ∣ x ) ( log ⁡ p θ ( x , z ) q ϕ ( z ∣ x ) ) . \begin{array}{ll} \log p_{\theta}(x) & = \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} + \log \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)}) \\ &= \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} )+D_{KL}(q_{\phi}(z|x)\| p_{\theta}(z |x ))\\ & \ge \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} ) \end{array}. logpθ(x)=Eqϕ(zx)(logqϕ(zx)pθ(x,z)+logpθ(zx)qϕ(zx))=Eqϕ(zx)(logqϕ(zx)pθ(x,z))+DKL(qϕ(zx)pθ(zx))Eqϕ(zx)(logqϕ(zx)pθ(x,z)).

既然KL散度非负, 我们极大似然 log ⁡ p θ ( x ) \log p_{\theta}(x) logpθ(x)可以退而求其次, 最大化 E q ϕ ( z ∣ x ) ( log ⁡ p θ ( x , z ) q ϕ ( z ∣ x ) ) \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} ) Eqϕ(zx)(logqϕ(zx)pθ(x,z))(ELBO, 记为 L \mathcal{L} L).

又( p θ ( z ) p_{\theta}(z) pθ(z)为人为给定的先验分布)
L ( θ , ϕ ; x ) = − D K L ( q ϕ ( z ∣ x ) ∥ p θ ( z ) ) + E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x ∣ z ) ] , \begin{array}{ll} \mathcal{L}(\theta, \phi; x) &= -D_{KL}(q_{\phi}(z|x)\|p_{\theta}(z))+\mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)], \end{array} L(θ,ϕ;x)=DKL(qϕ(zx)pθ(z))+Eqϕ(zx)[logpθ(xz)],
我们接下来通过对Encoder和Decoder的一些构造进一步扩展上面俩项.

Encoder (损失part1)

Encoder 将 x → z x\rightarrow z xz, 就相当于在 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx)中进行采样, 但是如果是直接采样的话, 就没法利用梯度回传进行训练了, 这里需要一个重参化技巧.

我们假设 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx)为高斯密度函数, 即 N ( μ , σ 2 I ) \mathcal{N}(\mu, \sigma^2 I) N(μ,σ2I).
注: 文中还提到了其他的一些可行假设.

我们构建一个神经网络 f f f, 其输入为样本 x x x, 输出为 ( μ , log ⁡ σ ) (\mu, \log \sigma) (μ,logσ)(输出 log ⁡ σ \log \sigma logσ是为了保证 σ \sigma σ为正), 则
z = μ + ϵ ⊙ σ , ϵ ∼ N ( 0 , I ) , z= \mu + \epsilon \odot \sigma, \epsilon \sim \mathcal{N}(0, I), z=μ+ϵσ,ϵN(0,I),
其中 ⊙ \odot 表示按元素相乘.
注: 我们可以该输出为 ( μ , L ) (\mu, L) (μ,L)( L L L为三角矩阵, 且对角线元素非负), 而假设 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx)的分量不独立, 其协方差函数为 L T L L^TL LTL, 则 ( z = μ + L ϵ (z=\mu + L \epsilon (z=μ+Lϵ).

p θ ( z ) = N ( 0 , I ) p_{\theta}(z)=\mathcal{N}(0, I) pθ(z)=N(0,I), 我们可以显示表达出:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

Decoder (损失part2)

现在我们需要处理的是第二项, 文中这地方因为直接设计 p θ ( x , z ) p_{\theta}(x,z) pθ(x,z)不容易, 在我看来存粹是做不到的, 但是又用普通的分布代替不符合常理, 所以首先设计一个网络 g θ ( z ) g_{\theta}(z) gθ(z), 其输出为 x ^ \hat{x} x^, 然后假设 p ( x ∣ x ^ ) p(x|\hat{x}) p(xx^)的分布, 第二项就改为近似 E q ϕ ( z ∣ x ) p θ ( x ∣ x ^ ) \mathbb{E}_{q_{\phi}(z|x)}p_{\theta}(x|\hat{x}) Eqϕ(zx)pθ(xx^).

这么做的好处是显而易见的, 因为Decoder部分, 我们可以通过给定一个 z z z然后获得一个 x ^ \hat{x} x^, 这是很有用的东西, 但是我认为这种不是很合理, 因为除非 g g g是可逆的, 那么 p θ ( x ∣ z ) = p θ ( x ∣ x ^ ) p_{\theta}(x|z)= p _{\theta}(x|\hat{x}) pθ(xz)=pθ(xx^) (当然, 别无选择).

伯努利分布

此时 x ^ = g ( z ) \hat{x}=g(z) x^=g(z) x = 1 x=1 x=1的概率, 则此时第二项的损失为
log ⁡ p ( x ∣ x ^ ) = ∑ i = 1 x i log ⁡ x ^ i + ( 1 − x i ) log ⁡ ( 1 − x ^ i ) , \log p(\mathbf{x}| \hat{\mathbf{x}})= \sum_{i=1} x_i \log \hat{x}_i + (1-x_i) \log (1- \hat{x}_i), logp(xx^)=i=1xilogx^i+(1xi)log(1x^i),
为(二分类)交叉熵损失.

高斯分布

一种简单粗暴的, p ( x ∣ x ^ ) = N ( x ^ , σ 2 I ) p(x|\hat{x})=\mathcal{N}(\hat{x},\sigma^2 I) p(xx^)=N(x^,σ2I), 此时损失为类平方损失, 文中也有别的变换.

代码

import torch
import torch.nn as nn


class Loss(nn.Module):
    def __init__(self, part2):
        super(Loss, self).__init__()
        self.part2 = part2

    def forward(self, mu, sigma, real, fake, lam=1):
        part1 = (1 + torch.log(sigma ** 2)
                 - mu ** 2 - sigma ** 2).sum() / 2
        part2 = self.part2(fake, real)
        return part1 + lam * part2
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值