扩散模型 (Diffusion Model) 简要介绍与源码分析
前言
近期同事分享了 Diffusion Model, 这才发现生成模型的发展已经到了如此惊人的地步, OpenAI 推出的 Dall-E 2 可以根据文本描述生成极为逼真的图像, 质量之高直让人惊呼哇塞. 今早公众号给我推送了一篇关于 Stability AI 公司的报道, 他们推出的 AI 文生图扩散模型 Stable Diffusion 已开源, 能够在消费级显卡上实现 Dall-E 2 级别的图像生成, 效率提升了 30 倍.
于是找到他们的开源产品体验了一把, 在线体验地址在 https://huggingface.co/spaces/stabilityai/stable-diffusion (开源代码在 Github 上: https://github.com/CompVis/stable-diffusion), 在搜索框中输入 "A dog flying in the sky" (一只狗在天空飞翔), 生成效果如下:
Amazing! 当然, 不是每一张图片都符合预期, 但好在可以生成无数张图片, 其中总有效果好的. 在震惊之余, 不免对 Diffusion Model (扩散模型) 背后的原理感兴趣, 就想看看是怎么实现的.
当时同事分享时, PPT 上那一堆堆公式扑面而来, 把我给整懵圈了, 但还是得撑起下巴, 表现出似有所悟、深以为然的样子, 在讲到关键处不由暗暗点头以表示理解和赞许. 后面花了个周末专门学习了一下, 公式推导+代码分析, 感觉终于了解了基本概念, 于是记录下来形成此文, 不敢说自己完全懂了, 毕竟我不做这个方向, 但回过头去看 PPT 上的公式就不再发怵了.
广而告之
可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号, 可以及时获取最新原创技术文章更新.
另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中.
总览
本文对 Diffusion Model 扩散模型的原理进行简要介绍, 然后对源码进行分析. 扩散模型的实现有多种形式, 本文关注的是 DDPM (denoising diffusion probabilistic models). 在介绍完基本原理后, 对作者释放的 Tensorflow 源码进行分析, 加深对各种公式的理解.
发现知乎影响了公式的排版, 如果觉得看着影响逻辑, 可以到 https://blog.csdn.net/Eric_1993/article/details/127455977 获得更好的阅读体验, 本篇公式比较多. 发现知乎上公式编号没了... CSDN 的 markdown 编辑器还是挺好用的. 只关注代码的话可以跳到源码分析进行阅读.
参考文章
在理解扩散模型的路上, 受到下面这些文章的启发, 强烈推荐阅读:
- Lilian 的博客, 内容非常非常详实, 干货十足, 而且每篇文章都极其用心, 向大佬学习: What are Diffusion Models?
- ewrfcas 的知乎, 公式推导补充了更多的细节: 由浅入深了解Diffusion Model
- Lilian 的博客, 介绍变分自动编码器 VAE: From Autoencoder to Beta-VAE, Diffusion Model 需要从分布中随机采样样本, 该过程无法求导, 需要使用到 VAE 中介绍的重参数技巧.
- Denoising Diffusion Probabilistic Models 论文,
- 其 TF 源码位于: https://github.com/hojonathanho/diffusion, 源码介绍以该版本为主
- PyTorch 的开源实现: https://github.com/lucidrains/denoising-diffusion-pytorch, 核心逻辑和上面 Tensorflow 版本是一致的, Stable Diffusion 参考的是 pytorch 版本的代码.
扩散模型介绍
基本原理
Diffusion Model (扩散模型) 是一类生成模型, 和 VAE (Variational Autoencoder, 变分自动编码器), GAN (Generative Adversarial Network, 生成对抗网络) 等生成网络不同的是, 扩散模型在前向阶段对图像逐步施加噪声, 直至图像被破坏变成完全的高斯噪声, 然后在逆向阶段学习从高斯噪声还原为原始图像的过程.
具体来说, 前向阶段在原始图像 �0 上逐步增加噪声, 每一步得到的图像 �� 只和上一步的结果 ��−1 相关, 直至第 � 步的图像 �� 变为纯高斯噪声. 前向阶段图示如下:
而逆向阶段则是不断去除噪声的过程, 首先给定高斯噪声 ��, 通过逐步去噪, 直至最终将原图像 �0 给恢复出来, 逆向阶段图示如下:
模型训练完成后, 只要给定高斯随机噪声, 就可以生成一张从未见过的图像. 下面分别介绍前向阶段和逆向阶段, 只列出重要公式。
前向阶段
由于前向过程中图像 �� 只和上一时刻的 ��−1 有关, 该过程可以视为马尔科夫过程, 满足:
�(�1:�∣�0)=∏�=1��(��∣��−1)�(��∣��−1)=�(��;1−����−1,���),
其中 ��∈(0,1) 为高斯分布的方差超参, 并满足 �1<�2<…<��. 另外公式 (2) 中为何均值 ��−1 前乘上系数 1−����−1 的原因将在后面的推导介绍. 上述过程的一个美妙性质是我们可以在任意 time step 下通过 重参数技巧 采样得到 ��.
重参数技巧 (reparameterization trick) 是为了解决随机采样样本这一过程无法求导的问题. 比如要从高斯分布 �∼�(�;�,�2�) 中采样样本 �, 可以通过引入随机变量 �∼�(0,�), 使得 �=�+�⊙�, 此时 � 依旧具有随机性, 且服从高斯分布 �(�,�2�), 同时 � 与 � (通常由网络生成) 可导.
简要了解了重参数技巧后, 再回到上面通过公式 (2) 采样 �� 的方法, 即生成随机变量 ��∼�(0,�), 然后令 ��=1−��, 以及 �¯�=∏�=1���, 从而可以得到:
��=1−����−1+���1 where �1,�2,…∼�(0,�),reparameter trick;=����−1+1−���1=��(��−1��−2+1−��−1�2)+1−���1(3-1)=����−1��−2+(��(1−��−1)�2+1−���1)(3-2)=����−1��−2+1−����−1�¯2 where �¯2∼�(0,�);=…=�¯��0+1−�¯��¯�.
其中公式 (3-1) 到公式 (3-2) 的推导是由于独立高斯分布的可见性, 有 �(0,�12�)+�(0,�22�)∼�(0,(�12+�22)�), 因此:
��(1−��−1)�2∼�(0,��(1−��−1)�)1−���1∼�(0,(1−��)�)��(1−��−1)�2+1−���1∼�(0,[��(1−��−1)+(1−��)]�)=�(0,(1−����−1)�).
注意公式 (3-2) 中 �¯2∼�(0,�), 因此还需乘上 1−����−1. 从公式 (3) 可以看出
�(��∣�0)=�(��;�¯��0,(1−�¯�)�)
注意由于 ��∈(0,1) 且 �1<…<��, 而 ��=1−��, 因此 ��∈(0,1) 并且有 �1>…>��, 另外由于 �¯�=∏�=1���, 因此当 �→∞ 时, �¯�→0 以及 (1−�¯�)→1, 此时 ��∼�(0,�). 从这里的推导来看, 在公式 (2) 中的均值 ��−1 前乘上系数 1−����−1 会使得 �� 最后收敛到标准高斯分布.
逆向阶段
前向阶段是加噪声的过程, 而逆向阶段则是将噪声去除, 如果能得到逆向过程的分布 �(��−1∣��), 那么通过输入高斯噪声 ��∼�(0,�), 我们将生成一个真实的样本. 注意到当 �� 足够小时, �(��−1∣��) 也是高斯分布, 具体的证明在 ewrfcas 的知乎文章: 由浅入深了解Diffusion Model 推荐的论文中: On the theory of stochastic processes, with particular reference to applications
. 我大致看了一下, 哈哈, 没太看明白, 不过想到这个不是我关注的重点, 因此 pass. 由于我们无法直接推断 �(��−1∣��), 因此我们将使用深度学习模型 �� 去拟合分布 �(��−1∣��), 模型参数为 �:
��(�0:�)=�(��)∏�=1���(��−1∣��)��(��−1∣��)=�(��−1;��(��,�),Σ�(��,�))
注意到, 虽然我们无法直接求得 �(��−1∣��) (注意这里是 � 而不是模型 ��), 但在知道 �0 的情况下, 可以通过贝叶斯公式得到 �(��−1∣��,�0) 为:
�(��−1∣��,�0)=�(��−1;�~(��,�0),�~��)
推导过程如下:
方差均值与无关�(��−1|��,�0)=�(��|��−1,�0)�(��−1|�0)�(��|�0)∝exp(−12((��−����−1)2��+(��−1−�¯�−1�0)21−�¯�−1−(��−�¯��0)21−�¯�))=exp(−12(��2−2������−1+����−12��+��−12−2�¯�−1�0��−1+�¯�−1�021−�¯�−1−(��−�¯��0)21−�¯�))=exp(−12((����+11−�¯�−1)��−12⏟��−1 方差 −(2������+2�¯�−11−�¯�−1�0)��−1⏟��−1 均值 +�(��,�0)⏟与 ��−1 无关 ))
上面推导过程中, 通过贝叶斯公式巧妙的将逆向过程转换为前向过程, 且最终得到的概率密度函数和高斯概率密度函数的指数部分 exp(−(�−�)22�2)=exp(−12(1�2�2−2��2�+�2�2)) 能对应, 即有:
�~�=1/(����+11−�¯�−1)=1/(��−�¯�+����(1−�¯�−1))=1−�¯�−11−�¯�⋅���~�(��,�0)=(������+�¯�−11−�¯�−1�0)/(����+11−�¯�−1)=(������+�¯�−11−�¯�−1�0)1−�¯�−11−�¯�⋅��=��(1−�¯�−1)1−�¯���+�¯�−1��1−�¯��0
通过公式 (8) 和公式 (9), 我们能得到 �(��−1∣��,�0) (见公式 (7)) 的分布. 此外由于公式 (3) 揭示的 �� 和 �0 之间的关系: ��=�¯��0+1−�¯��¯�, 可以得到
�0=1�¯�(��−1−�¯���)
代入公式 (9) 中得到:
�~�=��(1−�¯�−1)1−�¯���+�¯�−1��1−�¯�1�¯�(��−1−�¯���)=1��(��−1−��1−�¯���)
补充一下公式 (11) 的详细推导过程:
前面说到, 我们将使用深度学习模型 �� 去拟合逆向过程的分布 �(��−1∣��), 由公式 (6) 知 ��(��−1∣��)=�(��−1;��(��,�),Σ�(��,�)), 我们希望训练模型 ��(��,�) 以预估 �~�=1��(��−1−��1−�¯���). 由于 �� 在训练阶段会作为输入, 因此它是已知的, 我们可以转而让模型去预估噪声 ��, 即令:
��(��,�)=1��(��−1−��1−�¯���(��,�))Thus ��−1=�(��−1;1��(��−1−��1−�¯���(��,�)),��(��,�))
模型训练
前面谈到, 逆向阶段让模型去预估噪声 ��(��,�), 那么应该如何设计 Loss 函数 ? 我们的目标是在真实数据分布下, 最大化模型预测分布的对数似然, 即优化在 �0∼�(�0) 下的 ��(�0) 交叉熵:
�=��(�0)[−log��(�0)]
和 变分自动编码器 VAE 类似, 使用 Variational Lower Bound 来优化: −log��(�0) :
注注意散度非负与无关−log��(�0)≤−log��(�0)+���(�(�1:�∣�0)‖��(�1:�∣�0));注: 注意KL散度非负=−log��(�0)+��(�1:�∣�0)[log�(�1:�∣�0)��(�0:�)/��(�0)]; where ��(�1:�∣�0)=��(�0:�)��(�0)=−log��(�0)+��(�1:�∣�0)[log�(�1:�∣�0)��(�0:�)+log��(�0)⏟与q无关 ]=��(�1:�∣�0)[log�(�1:�∣�0)��(�0:�)].
对公式 (15) 左右两边取期望 ��(�0), 利用到重积分中的 Fubini 定理 可得:
定理����=��(�0)(��(�1:�∣�0)[log�(�1:�∣�0)��(�0:�)])=��(�0:�)[log�(�1:�∣�0)��(�0:�)]⏟Fubini定理 ≥��(�0)[−log��(�0)]
因此最小化 ���� 就可以优化公式 (14) 中的目标函数. 之后对 ���� 做进一步的推导, 这部分的详细推导见上面的参考文章, 最终的结论是:
����=��+��−1+…+�0��=���(�(��|�0)||��(��))��=���(�(��|��+1,�0)||��(��|��+1));1≤�≤�−1�0=−log��(�0|�1)
最终是优化两个高斯分布 �(��|��−1,�0)=�(��−1;�~(��,�0),�~��) (详见公式 (7)) 与 ��(��|��+1)=�(��−1;��(��,�),Σ�) (详见公式(6), 此为模型预估的分布)之间的 KL 散度. 由于多元高斯分布的 KL 散度存在闭式解, 详见: Multivariate_normal_distributions, 从而可以得到:
其中为高斯噪声为模型学习的噪声��=��0,�[12‖��(��,�)‖22‖�~�(��,�0)−��(��,�)‖2]=��0,�[12‖��‖22‖1��(��−1−��1−�¯���)−1��(��−1−��1−�¯���(��,�))‖2]=��0,�[(1−��)22��(1−�¯�)‖��‖22‖��−��(��,�)‖2];其中��为高斯噪声,��为模型学习的噪声=��0,�[(1−��)22��(1−�¯�)‖��‖22‖��−��(�¯��0+1−�¯���,�)‖2]
DDPM 将 Loss 简化为如下形式:
��simple =��0,��[‖��−��(�¯��0+1−�¯���,�)‖2]
因此 Diffusion 模型的目标函数即是学习高斯噪声 �� 和 �� (来自模型输出) 之间的 MSE loss.
最终算法
最终 DDPM 的算法流程如下:
训练阶段重复如下步骤:
- 从数据集中采样 �0
- 随机选取 time step �
- 生成高斯噪声 ��∈�(0,�)
- 调用模型预估 ��(�¯��0+1−�¯���,�)
- 计算噪声之间的 MSE Loss: ‖��−��(�¯��0+1−�¯���,�)‖2, 并利用反向传播算法训练模型.
逆向阶段采用如下步骤进行采样:
- 从高斯分布采样 ��
- 按照 �,…,1 的顺序进行迭代:
- 如果 �=1, 令 �=0; 如果 �>1, 从高斯分布中采样 �∼�(0,�)
- 利用公式 (12) 学习出均值 ��(��,�)=1��(��−1−��1−�¯���(��,�)), 并利用公式 (8) 计算均方差 ��=�~�=1−�¯�−11−�¯�⋅��
- 通过重参数技巧采样 ��−1=��(��,�)+���
- 经过以上过程的迭代, 最终恢复 �0.
源码分析
DDPM 文章以及代码的相关信息如下:
- Denoising Diffusion Probabilistic Models 论文,
- 其 TF 源码位于: https://github.com/hojonathanho/diffusion, 源码介绍以该版本为主
- PyTorch 的开源实现: https://github.com/lucidrains/denoising-diffusion-pytorch, 核心逻辑和上面 Tensorflow 版本是一致的, Stable Diffusion 参考的是 pytorch 版本的代码.
本文以分析 Tensorflow 源码为主, Pytorch 版本的代码和 Tensorflow 版本的实现逻辑大体不差的, 变量名字啥的都类似, 阅读起来不会有啥门槛. Tensorlow 源码对 Diffusion 模型的实现位于 diffusion_utils_2.py, 模型本身的分析以该文件为主.
训练阶段
以 CIFAR 数据集为例.
在 run_cifar.py 中进行前向传播计算 Loss:
- 第 6 行随机选出 �∼Uniform({1,…,�})
- 第 7 行
training_losses
定义在 GaussianDiffusion2 中, 计算噪声间的 MSE Loss.
进入 GaussianDiffusion2 中, 看到初始化函数中定义了诸多变量, 我在注释中使用公式的方式进行了说明:
下面进入到 training_losses
函数中:
- 第 19 行:
self.model_mean_type
默认是eps
, 模型学习的是噪声, 因此target
是第 6 行定义的noise
, 即 �� - 第 9 行: 调用
self.q_sample
计算 ��, 即公式 (3) ��=�¯��0+1−�¯��� - 第 21 行:
denoise_fn
是定义在 unet.py 中的UNet
模型, 只需知道它的输入和输出大小相同; 结合第 9 行得到的 ��, 得到模型预估的噪声: ��(�¯��0+1−�¯���,�) - 第 23 行: 计算两个噪声之间的 MSE: ‖��−��(�¯��0+1−�¯���,�)‖2, 并利用反向传播算法训练模型
上面第 9 行定义的 self.q_sample
详情如下:
- 第 13 行的
q_sample
已经介绍过, 不多说. - 第 2 行的
_extract
在代码中经常被使用到, 看到它只需知道它是用来提取系数的即可. 引入输入是一个 Batch, 里面的每个样本都会随机采样一个 time step �, 因此需要使用tf.gather
来将 ��¯ 之类选出来, 然后将系数 reshape 为[B, 1, 1, ....]
的形式, 目的是为了利用 broadcasting 机制和 �� 这个 Tensor 相乘.
前向的训练阶段代码实现非常简单, 下面看逆向阶段
逆向阶段
逆向阶段代码定义在 GaussianDiffusion2 中:
- 第 5 行生成高斯噪声 ��, 然后对其不断去噪直至恢复原始图像
- 第 11 行的
self.p_sample
就是公式 (6) ��(��−1∣��)=�(��−1;��(��,�),Σ�(��,�)) 的过程, 使用模型来预估 ��(��,�) 以及 Σ�(��,�) - 第 12 行的
denoise_fn
在前面说过, 是定义在 unet.py 中的UNet
模型;img_
表示 ��. - 第 13 行的
noise_fn
则默认是tf.random_normal
, 用于生成高斯噪声.
进入 p_sample
函数:
- 第 7 行调用
self.p_mean_variance
生成 ��(��,�) 以及 log(Σ�(��,�)), 其中 Σ�(��,�) 通过计算 �~� 得到. - 第 11 行从高斯分布中采样 �
- 第 18 行通过重参数技巧采样 ��−1=��(��,�)+���, 其中 ��=�~�
进入 self.p_mean_variance
函数:
- 第 6 行调用模型
denoise_fn
, 通过输入 ��, 输出得到噪声 �� - 第 19 行
self.model_var_type
默认为fixedlarge
, 但我当时看fixedsmall
比较爽, 因此model_variance
和model_log_variance
分别为 �~�=1−�¯�−11−�¯�⋅�� (见公式 8), 以及 log�~� - 第 29 行调用
self._predict_xstart_from_eps
函数, 利用公式 (10) 得到 �0=1�¯�(��−1−�¯���) - 第 30 行调用
self.q_posterior_mean_variance
通过公式 (9) 得到 ��(��,�0)=��(1−�¯�−1)1−�¯���+�¯�−1��1−�¯��0
self._predict_xstart_from_eps
函数相亲如下:
- 该函数计算 �0=1�¯�(��−1−�¯���)
self.q_posterior_mean_variance
函数详情如下:
- 相关说明见注释, 另外发现对于 ��(��,�0) 的计算使用的是公式 (9) ��(��,�0)=��(1−�¯�−1)1−�¯���+�¯�−1��1−�¯��0 而不是进一步推导后的公式 (11) ��(��,�0)=1��(��−1−��1−�¯���).
总结
写文章真的挺累的, 好处是, 我发现写之前我以为理解了, 但写的过程中又发现有些地方理解的不对. 写完后才终于把逻辑理顺.