扩散模型 (Diffusion Model) 简要介绍与源码分析

扩散模型 (Diffusion Model) 简要介绍与源码分析

扩散模型 (Diffusion Model) 简要介绍与源码分析

前言

近期同事分享了 Diffusion Model, 这才发现生成模型的发展已经到了如此惊人的地步, OpenAI 推出的 Dall-E 2 可以根据文本描述生成极为逼真的图像, 质量之高直让人惊呼哇塞. 今早公众号给我推送了一篇关于 Stability AI 公司的报道, 他们推出的 AI 文生图扩散模型 Stable Diffusion 已开源, 能够在消费级显卡上实现 Dall-E 2 级别的图像生成, 效率提升了 30 倍.

于是找到他们的开源产品体验了一把, 在线体验地址在 https://huggingface.co/spaces/stabilityai/stable-diffusion (开源代码在 Github 上: https://github.com/CompVis/stable-diffusion), 在搜索框中输入 "A dog flying in the sky" (一只狗在天空飞翔), 生成效果如下:

Amazing! 当然, 不是每一张图片都符合预期, 但好在可以生成无数张图片, 其中总有效果好的. 在震惊之余, 不免对 Diffusion Model (扩散模型) 背后的原理感兴趣, 就想看看是怎么实现的.

当时同事分享时, PPT 上那一堆堆公式扑面而来, 把我给整懵圈了, 但还是得撑起下巴, 表现出似有所悟、深以为然的样子, 在讲到关键处不由暗暗点头以表示理解和赞许. 后面花了个周末专门学习了一下, 公式推导+代码分析, 感觉终于了解了基本概念, 于是记录下来形成此文, 不敢说自己完全懂了, 毕竟我不做这个方向, 但回过头去看 PPT 上的公式就不再发怵了.

广而告之

可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号, 可以及时获取最新原创技术文章更新.

另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中.

总览

本文对 Diffusion Model 扩散模型的原理进行简要介绍, 然后对源码进行分析. 扩散模型的实现有多种形式, 本文关注的是 DDPM (denoising diffusion probabilistic models). 在介绍完基本原理后, 对作者释放的 Tensorflow 源码进行分析, 加深对各种公式的理解.

发现知乎影响了公式的排版, 如果觉得看着影响逻辑, 可以到 https://blog.csdn.net/Eric_1993/article/details/127455977 获得更好的阅读体验, 本篇公式比较多. 发现知乎上公式编号没了... CSDN 的 markdown 编辑器还是挺好用的. 只关注代码的话可以跳到源码分析进行阅读.

参考文章

在理解扩散模型的路上, 受到下面这些文章的启发, 强烈推荐阅读:

扩散模型介绍

基本原理

Diffusion Model (扩散模型) 是一类生成模型, 和 VAE (Variational Autoencoder, 变分自动编码器), GAN (Generative Adversarial Network, 生成对抗网络) 等生成网络不同的是, 扩散模型在前向阶段对图像逐步施加噪声, 直至图像被破坏变成完全的高斯噪声, 然后在逆向阶段学习从高斯噪声还原为原始图像的过程.

具体来说, 前向阶段在原始图像 �0 上逐步增加噪声, 每一步得到的图像 �� 只和上一步的结果 ��−1 相关, 直至第 � 步的图像 �� 变为纯高斯噪声. 前向阶段图示如下:

而逆向阶段则是不断去除噪声的过程, 首先给定高斯噪声 ��, 通过逐步去噪, 直至最终将原图像 �0 给恢复出来, 逆向阶段图示如下:

模型训练完成后, 只要给定高斯随机噪声, 就可以生成一张从未见过的图像. 下面分别介绍前向阶段和逆向阶段, 只列出重要公式。

前向阶段

由于前向过程中图像 �� 只和上一时刻的 ��−1 有关, 该过程可以视为马尔科夫过程, 满足:

�(�1:�∣�0)=∏�=1��(��∣��−1)�(��∣��−1)=�(��;1−����−1,���),

其中 ��∈(0,1) 为高斯分布的方差超参, 并满足 �1<�2<…<��. 另外公式 (2) 中为何均值 ��−1 前乘上系数 1−����−1 的原因将在后面的推导介绍. 上述过程的一个美妙性质是我们可以在任意 time step 下通过 重参数技巧 采样得到 ��.

重参数技巧 (reparameterization trick) 是为了解决随机采样样本这一过程无法求导的问题. 比如要从高斯分布 �∼�(�;�,�2�) 中采样样本 �, 可以通过引入随机变量 �∼�(0,�), 使得 �=�+�⊙�, 此时 � 依旧具有随机性, 且服从高斯分布 �(�,�2�), 同时 � 与 � (通常由网络生成) 可导.

简要了解了重参数技巧后, 再回到上面通过公式 (2) 采样 �� 的方法, 即生成随机变量 ��∼�(0,�), 然后令 ��=1−��, 以及 �¯�=∏�=1���, 从而可以得到:

��=1−����−1+���1 where �1,�2,…∼�(0,�),reparameter trick;=����−1+1−���1=��(��−1��−2+1−��−1�2)+1−���1(3-1)=����−1��−2+(��(1−��−1)�2+1−���1)(3-2)=����−1��−2+1−����−1�¯2 where �¯2∼�(0,�);=…=�¯��0+1−�¯��¯�.

其中公式 (3-1) 到公式 (3-2) 的推导是由于独立高斯分布的可见性, 有 �(0,�12�)+�(0,�22�)∼�(0,(�12+�22)�), 因此:

��(1−��−1)�2∼�(0,��(1−��−1)�)1−���1∼�(0,(1−��)�)��(1−��−1)�2+1−���1∼�(0,[��(1−��−1)+(1−��)]�)=�(0,(1−����−1)�).

注意公式 (3-2) 中 �¯2∼�(0,�), 因此还需乘上 1−����−1. 从公式 (3) 可以看出

�(��∣�0)=�(��;�¯��0,(1−�¯�)�)

注意由于 ��∈(0,1) 且 �1<…<��, 而 ��=1−��, 因此 ��∈(0,1) 并且有 �1>…>��, 另外由于 �¯�=∏�=1���, 因此当 �→∞ 时, �¯�→0 以及 (1−�¯�)→1, 此时 ��∼�(0,�). 从这里的推导来看, 在公式 (2) 中的均值 ��−1 前乘上系数 1−����−1 会使得 �� 最后收敛到标准高斯分布.

逆向阶段

前向阶段是加噪声的过程, 而逆向阶段则是将噪声去除, 如果能得到逆向过程的分布 �(��−1∣��), 那么通过输入高斯噪声 ��∼�(0,�), 我们将生成一个真实的样本. 注意到当 �� 足够小时, �(��−1∣��) 也是高斯分布, 具体的证明在 ewrfcas 的知乎文章: 由浅入深了解Diffusion Model 推荐的论文中: On the theory of stochastic processes, with particular reference to applications. 我大致看了一下, 哈哈, 没太看明白, 不过想到这个不是我关注的重点, 因此 pass. 由于我们无法直接推断 �(��−1∣��), 因此我们将使用深度学习模型 �� 去拟合分布 �(��−1∣��), 模型参数为 �:

��(�0:�)=�(��)∏�=1���(��−1∣��)��(��−1∣��)=�(��−1;��(��,�),Σ�(��,�))

注意到, 虽然我们无法直接求得 �(��−1∣��) (注意这里是 � 而不是模型 ��), 但在知道 �0 的情况下, 可以通过贝叶斯公式得到 �(��−1∣��,�0) 为:

�(��−1∣��,�0)=�(��−1;�~(��,�0),�~��)

推导过程如下:

方差均值与无关�(��−1|��,�0)=�(��|��−1,�0)�(��−1|�0)�(��|�0)∝exp⁡(−12((��−����−1)2��+(��−1−�¯�−1�0)21−�¯�−1−(��−�¯��0)21−�¯�))=exp⁡(−12(��2−2������−1+����−12��+��−12−2�¯�−1�0��−1+�¯�−1�021−�¯�−1−(��−�¯��0)21−�¯�))=exp⁡(−12((����+11−�¯�−1)��−12⏟��−1 方差 −(2������+2�¯�−11−�¯�−1�0)��−1⏟��−1 均值 +�(��,�0)⏟与 ��−1 无关 ))

上面推导过程中, 通过贝叶斯公式巧妙的将逆向过程转换为前向过程, 且最终得到的概率密度函数和高斯概率密度函数的指数部分 exp⁡(−(�−�)22�2)=exp⁡(−12(1�2�2−2��2�+�2�2)) 能对应, 即有:

�~�=1/(����+11−�¯�−1)=1/(��−�¯�+����(1−�¯�−1))=1−�¯�−11−�¯�⋅���~�(��,�0)=(������+�¯�−11−�¯�−1�0)/(����+11−�¯�−1)=(������+�¯�−11−�¯�−1�0)1−�¯�−11−�¯�⋅��=��(1−�¯�−1)1−�¯���+�¯�−1��1−�¯��0

通过公式 (8) 和公式 (9), 我们能得到 �(��−1∣��,�0) (见公式 (7)) 的分布. 此外由于公式 (3) 揭示的 �� 和 �0 之间的关系: ��=�¯��0+1−�¯��¯�, 可以得到

�0=1�¯�(��−1−�¯���)

代入公式 (9) 中得到:

�~�=��(1−�¯�−1)1−�¯���+�¯�−1��1−�¯�1�¯�(��−1−�¯���)=1��(��−1−��1−�¯���)

补充一下公式 (11) 的详细推导过程:

前面说到, 我们将使用深度学习模型 �� 去拟合逆向过程的分布 �(��−1∣��), 由公式 (6) 知 ��(��−1∣��)=�(��−1;��(��,�),Σ�(��,�)), 我们希望训练模型 ��(��,�) 以预估 �~�=1��(��−1−��1−�¯���). 由于 �� 在训练阶段会作为输入, 因此它是已知的, 我们可以转而让模型去预估噪声 ��, 即令:

��(��,�)=1��(��−1−��1−�¯���(��,�))Thus ��−1=�(��−1;1��(��−1−��1−�¯���(��,�)),��(��,�))

模型训练

前面谈到, 逆向阶段让模型去预估噪声 ��(��,�), 那么应该如何设计 Loss 函数 ? 我们的目标是在真实数据分布下, 最大化模型预测分布的对数似然, 即优化在 �0∼�(�0) 下的 ��(�0) 交叉熵:

�=��(�0)[−log⁡��(�0)]

和 变分自动编码器 VAE 类似, 使用 Variational Lower Bound 来优化: −log⁡��(�0) :

注注意散度非负与无关−log⁡��(�0)≤−log⁡��(�0)+���(�(�1:�∣�0)‖��(�1:�∣�0));注: 注意KL散度非负=−log⁡��(�0)+��(�1:�∣�0)[log⁡�(�1:�∣�0)��(�0:�)/��(�0)]; where ��(�1:�∣�0)=��(�0:�)��(�0)=−log⁡��(�0)+��(�1:�∣�0)[log⁡�(�1:�∣�0)��(�0:�)+log⁡��(�0)⏟与q无关 ]=��(�1:�∣�0)[log⁡�(�1:�∣�0)��(�0:�)].

对公式 (15) 左右两边取期望 ��(�0), 利用到重积分中的 Fubini 定理 可得:

定理����=��(�0)(��(�1:�∣�0)[log⁡�(�1:�∣�0)��(�0:�)])=��(�0:�)[log⁡�(�1:�∣�0)��(�0:�)]⏟Fubini定理 ≥��(�0)[−log⁡��(�0)]

因此最小化 ���� 就可以优化公式 (14) 中的目标函数. 之后对 ���� 做进一步的推导, 这部分的详细推导见上面的参考文章, 最终的结论是:

����=��+��−1+…+�0��=���(�(��|�0)||��(��))��=���(�(��|��+1,�0)||��(��|��+1));1≤�≤�−1�0=−log⁡��(�0|�1)

最终是优化两个高斯分布 �(��|��−1,�0)=�(��−1;�~(��,�0),�~��) (详见公式 (7)) 与 ��(��|��+1)=�(��−1;��(��,�),Σ�) (详见公式(6), 此为模型预估的分布)之间的 KL 散度. 由于多元高斯分布的 KL 散度存在闭式解, 详见: Multivariate_normal_distributions, 从而可以得到:

其中为高斯噪声为模型学习的噪声��=��0,�[12‖��(��,�)‖22‖�~�(��,�0)−��(��,�)‖2]=��0,�[12‖��‖22‖1��(��−1−��1−�¯���)−1��(��−1−��1−�¯���(��,�))‖2]=��0,�[(1−��)22��(1−�¯�)‖��‖22‖��−��(��,�)‖2];其中��为高斯噪声,��为模型学习的噪声=��0,�[(1−��)22��(1−�¯�)‖��‖22‖��−��(�¯��0+1−�¯���,�)‖2]

DDPM 将 Loss 简化为如下形式:

��simple =��0,��[‖��−��(�¯��0+1−�¯���,�)‖2]

因此 Diffusion 模型的目标函数即是学习高斯噪声 �� 和 �� (来自模型输出) 之间的 MSE loss.

最终算法

最终 DDPM 的算法流程如下:

训练阶段重复如下步骤:

  • 从数据集中采样 �0
  • 随机选取 time step �
  • 生成高斯噪声 ��∈�(0,�)
  • 调用模型预估 ��(�¯��0+1−�¯���,�)
  • 计算噪声之间的 MSE Loss: ‖��−��(�¯��0+1−�¯���,�)‖2, 并利用反向传播算法训练模型.

逆向阶段采用如下步骤进行采样:

  • 从高斯分布采样 ��
  • 按照 �,…,1 的顺序进行迭代:
    • 如果 �=1, 令 �=0; 如果 �>1, 从高斯分布中采样 �∼�(0,�)
    • 利用公式 (12) 学习出均值 ��(��,�)=1��(��−1−��1−�¯���(��,�)), 并利用公式 (8) 计算均方差 ��=�~�=1−�¯�−11−�¯�⋅��
    • 通过重参数技巧采样 ��−1=��(��,�)+���
  • 经过以上过程的迭代, 最终恢复 �0.

源码分析

DDPM 文章以及代码的相关信息如下:

本文以分析 Tensorflow 源码为主, Pytorch 版本的代码和 Tensorflow 版本的实现逻辑大体不差的, 变量名字啥的都类似, 阅读起来不会有啥门槛. Tensorlow 源码对 Diffusion 模型的实现位于 diffusion_utils_2.py, 模型本身的分析以该文件为主.

训练阶段

以 CIFAR 数据集为例.

在 run_cifar.py 中进行前向传播计算 Loss:

  • 第 6 行随机选出 �∼Uniform({1,…,�})
  • 第 7 行 training_losses 定义在 GaussianDiffusion2 中, 计算噪声间的 MSE Loss.

进入 GaussianDiffusion2 中, 看到初始化函数中定义了诸多变量, 我在注释中使用公式的方式进行了说明:

下面进入到 training_losses 函数中:

  • 第 19 行: self.model_mean_type 默认是 eps, 模型学习的是噪声, 因此 target 是第 6 行定义的 noise, 即 ��
  • 第 9 行: 调用 self.q_sample 计算 ��, 即公式 (3) ��=�¯��0+1−�¯���
  • 第 21 行: denoise_fn 是定义在 unet.py 中的 UNet 模型, 只需知道它的输入和输出大小相同; 结合第 9 行得到的 ��, 得到模型预估的噪声: ��(�¯��0+1−�¯���,�)
  • 第 23 行: 计算两个噪声之间的 MSE: ‖��−��(�¯��0+1−�¯���,�)‖2, 并利用反向传播算法训练模型

上面第 9 行定义的 self.q_sample 详情如下:

  • 第 13 行的 q_sample 已经介绍过, 不多说.
  • 第 2 行的 _extract 在代码中经常被使用到, 看到它只需知道它是用来提取系数的即可. 引入输入是一个 Batch, 里面的每个样本都会随机采样一个 time step �, 因此需要使用 tf.gather 来将 ��¯ 之类选出来, 然后将系数 reshape 为 [B, 1, 1, ....] 的形式, 目的是为了利用 broadcasting 机制和 �� 这个 Tensor 相乘.

前向的训练阶段代码实现非常简单, 下面看逆向阶段

逆向阶段

逆向阶段代码定义在 GaussianDiffusion2 中:

  • 第 5 行生成高斯噪声 ��, 然后对其不断去噪直至恢复原始图像
  • 第 11 行的 self.p_sample 就是公式 (6) ��(��−1∣��)=�(��−1;��(��,�),Σ�(��,�)) 的过程, 使用模型来预估 ��(��,�) 以及 Σ�(��,�)
  • 第 12 行的 denoise_fn 在前面说过, 是定义在 unet.py 中的 UNet 模型; img_ 表示 ��.
  • 第 13 行的 noise_fn 则默认是 tf.random_normal, 用于生成高斯噪声.

进入 p_sample 函数:

  • 第 7 行调用 self.p_mean_variance 生成 ��(��,�) 以及 log⁡(Σ�(��,�)), 其中 Σ�(��,�) 通过计算 �~� 得到.
  • 第 11 行从高斯分布中采样 �
  • 第 18 行通过重参数技巧采样 ��−1=��(��,�)+���, 其中 ��=�~�

进入 self.p_mean_variance 函数:

  • 第 6 行调用模型 denoise_fn, 通过输入 ��, 输出得到噪声 ��
  • 第 19 行 self.model_var_type 默认为 fixedlarge, 但我当时看 fixedsmall 比较爽, 因此 model_variance 和 model_log_variance 分别为 �~�=1−�¯�−11−�¯�⋅�� (见公式 8), 以及 log⁡�~�
  • 第 29 行调用 self._predict_xstart_from_eps 函数, 利用公式 (10) 得到 �0=1�¯�(��−1−�¯���)
  • 第 30 行调用 self.q_posterior_mean_variance 通过公式 (9) 得到 ��(��,�0)=��(1−�¯�−1)1−�¯���+�¯�−1��1−�¯��0

self._predict_xstart_from_eps 函数相亲如下:

  • 该函数计算 �0=1�¯�(��−1−�¯���)

self.q_posterior_mean_variance 函数详情如下:

  • 相关说明见注释, 另外发现对于 ��(��,�0) 的计算使用的是公式 (9) ��(��,�0)=��(1−�¯�−1)1−�¯���+�¯�−1��1−�¯��0 而不是进一步推导后的公式 (11) ��(��,�0)=1��(��−1−��1−�¯���).

总结

写文章真的挺累的, 好处是, 我发现写之前我以为理解了, 但写的过程中又发现有些地方理解的不对. 写完后才终于把逻辑理顺.

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI周红伟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值