扩散模型 (Diffusion Model) 简要介绍与源码分析

 推荐 原创

珍妮的选择2023-03-15 22:09:01博主文章分类:机器学习©著作权

文章标签扩散模型DDPMStable-Diffusion深度学习计算机视觉文章分类计算机视觉人工智能yyds干货盘点阅读数1394

扩散模型 (Diffusion Model) 简要介绍与源码分析

Table of Contents

前言

近期同事分享了 Diffusion Model, 这才发现生成模型的发展已经到了如此惊人的地步, OpenAI 推出的  Dall-E 2 可以根据文本描述生成极为逼真的图像, 质量之高直让人惊呼哇塞. 今早公众号给我推送了一篇关于 Stability AI 公司的报道, 他们推出的 AI 文生图扩散模型 Stable Diffusion 已开源, 能够在消费级显卡上实现 Dall-E 2 级别的图像生成, 效率提升了 30 倍.

于是找到他们的开源产品体验了一把, 在线体验地址在  https://huggingface.co/spaces/stabilityai/stable-diffusion (开源代码在 Github 上:  https://github.com/CompVis/stable-diffusion), 在搜索框中输入 “A dog flying in the sky” (一只狗在天空飞翔), 生成效果如下:

扩散模型 (Diffusion Model) 简要介绍与源码分析_深度学习

Amazing! 当然, 不是每一张图片都符合预期, 但好在可以生成无数张图片, 其中总有效果好的. 在震惊之余, 不免对 Diffusion Model (扩散模型) 背后的原理感兴趣, 就想看看是怎么实现的.

当时同事分享时, PPT 上那一堆堆公式扑面而来, 把我给整懵圈了, 但还是得撑起下巴, 表现出似有所悟、深以为然的样子, 在讲到关键处不由暗暗点头以表示理解和赞许. 后面花了个周末专门学习了一下, 公式推导+代码分析, 感觉终于了解了基本概念, 于是记录下来形成此文, 不敢说自己完全懂了, 毕竟我不做这个方向, 但回过头去看 PPT 上的公式就不再发怵了.

广而告之

可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号, 可以及时获取最新原创技术文章更新.

另外可以看看知乎专栏  PoorMemory-机器学习, 以后文章也会发在知乎专栏中.

总览

本文对 Diffusion Model 扩散模型的原理进行简要介绍, 然后对源码进行分析. 扩散模型的实现有多种形式, 本文关注的是 DDPM (denoising diffusion probabilistic models). 在介绍完基本原理后, 对作者释放的 Tensorflow 源码进行分析, 加深对各种公式的理解.

参考文章

在理解扩散模型的路上, 受到下面这些文章的启发, 强烈推荐阅读:

扩散模型介绍

基本原理

Diffusion Model (扩散模型) 是一类生成模型, 和 VAE (Variational Autoencoder, 变分自动编码器), GAN (Generative Adversarial Network, 生成对抗网络) 等生成网络不同的是, 扩散模型在前向阶段对图像逐步施加噪声, 直至图像被破坏变成完全的高斯噪声, 然后在逆向阶段学习从高斯噪声还原为原始图像的过程.

具体来说, 前向阶段在原始图像 �0x0​ 上逐步增加噪声, 每一步得到的图像 ��xt​ 只和上一步的结果 ��−1xt−1​ 相关, 直至第 �T 步的图像 ��xT​ 变为纯高斯噪声. 前向阶段图示如下:

扩散模型 (Diffusion Model) 简要介绍与源码分析_DDPM_02

而逆向阶段则是不断去除噪声的过程, 首先给定高斯噪声 ��xT​, 通过逐步去噪, 直至最终将原图像 �0x0​ 给恢复出来, 逆向阶段图示如下:

扩散模型 (Diffusion Model) 简要介绍与源码分析_深度学习_03

模型训练完成后, 只要给定高斯随机噪声, 就可以生成一张从未见过的图像. 下面分别介绍前向阶段和逆向阶段, 只列出重要公式,

前向阶段

由于前向过程中图像 ��xt​ 只和上一时刻的 ��−1xt−1​ 有关, 该过程可以视为马尔科夫过程, 满足:

�(�1:�∣�0)=∏�=1��(��∣��−1)�(��∣��−1)=�(��;1−����−1,���),q(x1:T​∣x0​)q(xt​∣xt−1​)​=t=1∏T​q(xt​∣xt−1​)=N(xt​;1−βt​​xt−1​,βt​I),​​

其中 ��∈(0,1)βt​∈(0,1) 为高斯分布的方差超参, 并满足 �1<�2<…<��β1​<β2​<…<βT​. 另外公式 (2) 中为何均值 ��−1xt−1​ 前乘上系数 1−����−11−βt​​xt−1​ 的原因将在后面的推导介绍. 上述过程的一个美妙性质是我们可以在任意 time step 下通过  重参数技巧 采样得到 ��xt​.

 重参数技巧 (reparameterization trick) 是为了解决随机采样样本这一过程无法求导的问题. 比如要从高斯分布 �∼�(�;�,�2�)z∼N(z;μ,σ2I) 中采样样本 �z, 可以通过引入随机变量 �∼�(0,�)ϵ∼N(0,I), 使得 �=�+�⊙�z=μ+σ⊙ϵ, 此时 �z 依旧具有随机性, 且服从高斯分布 �(�,�2�)N(μ,σ2I), 同时 �μ 与 �σ (通常由网络生成) 可导.

简要了解了重参数技巧后, 再回到上面通过公式 (2) 采样 ��xt​ 的方法, 即生成随机变量 ��∼�(0,�)ϵt​∼N(0,I),
然后令 ��=1−��αt​=1−βt​, 以及 ��‾=∏�=1���αt​​=∏i=1T​αt​, 从而可以得到:

扩散模型 (Diffusion Model) 简要介绍与源码分析_扩散模型_04

其中公式 (3-1) 到公式 (3-2) 的推导是由于独立高斯分布的可见性, 有 �(0,�12�)+�(0,�22�)∼�(0,(�12+�22)�)N(0,σ12​I)+N(0,σ22​I)∼N(0,(σ12​+σ22​)I), 因此:

��(1−��−1)�2∼�(0,��(1−��−1)�)1−���1∼�(0,(1−��)�)��(1−��−1)�2+1−���1∼�(0,[��(1−��−1)+(1−��)]�)=�(0,(1−����−1)�).​at​(1−αt−1​)​ϵ2​∼N(0,at​(1−αt−1​)I)1−αt​​ϵ1​∼N(0,(1−αt​)I)at​(1−αt−1​)​ϵ2​+1−αt​​ϵ1​∼N(0,[αt​(1−αt−1​)+(1−αt​)]I)=N(0,(1−αt​αt−1​)I).​

注意公式 (3-2) 中 �ˉ2∼�(0,�)ϵˉ2​∼N(0,I), 因此还需乘上 1−����−11−αt​αt−1​​. 从公式 (3) 可以看出

�(��∣�0)=�(��;�ˉ��0,(1−�ˉ�)�)q(xt​∣x0​)=N(xt​;aˉt​​x0​,(1−aˉt​)I)​

注意由于 ��∈(0,1)βt​∈(0,1) 且 �1<…<��β1​<…<βT​, 而 ��=1−��αt​=1−βt​, 因此 ��∈(0,1)αt​∈(0,1) 并且有 �1>…>��α1​>…>αT​, 另外由于 �ˉ�=∏�=1���αˉt​=∏i=1T​αt​, 因此当 �→∞T→∞ 时, �ˉ�→0αˉt​→0 以及 (1−�ˉ�)→1(1−aˉt​)→1, 此时 ��∼�(0,�)xT​∼N(0,I). 从这里的推导来看, 在公式 (2) 中的均值 ��−1xt−1​ 前乘上系数 1−����−11−βt​​xt−1​ 会使得 ��xT​ 最后收敛到标准高斯分布.

逆向阶段

前向阶段是加噪声的过程, 而逆向阶段则是将噪声去除, 如果能得到逆向过程的分布 �(��−1∣��)q(xt−1​∣xt​), 那么通过输入高斯噪声 ��∼�(0,�)xT​∼N(0,I), 我们将生成一个真实的样本. 注意到当 ��βt​ 足够小时, �(��−1∣��)q(xt−1​∣xt​) 也是高斯分布, 具体的证明在 ewrfcas 的知乎文章:  由浅入深了解Diffusion Model 推荐的论文中: On the theory of stochastic processes, with particular reference to applications. 我大致看了一下, 哈哈, 没太看明白, 不过想到这个不是我关注的重点, 因此 pass. 由于我们无法直接推断 �(��−1∣��)q(xt−1​∣xt​), 因此我们将使用深度学习模型 ��pθ​ 去拟合分布 �(��−1∣��)q(xt−1​∣xt​), 模型参数为 �θ:

��(�0:�)=�(��)∏�=1���(��−1∣��)��(��−1∣��)=�(��−1;��(��,�),Σ�(��,�))pθ​(x0:T​)pθ​(xt−1​∣xt​)​=p(xT​)t=1∏T​pθ​(xt−1​∣xt​)=N(xt−1​;μθ​(xt​,t),Σθ​(xt​,t))​

注意到, 虽然我们无法直接求得 �(��−1∣��)q(xt−1​∣xt​) (注意这里是 �q 而不是模型 ��pθ​), 但在知道 �0x0​ 的情况下, 可以通过贝叶斯公式得到 �(��−1∣��,�0)q(xt−1​∣xt​,x0​) 为:

�(��−1∣��,�0)=�(��−1;�~(��,�0),�~��)q(xt−1​∣xt​,x0​)​=N(xt−1​;μ~​(xt​,x0​),β~​t​I)​

推导过程如下:

�(��−1∣��,�0)=�(��∣��−1,�0)�(��−1∣�0)�(��∣�0)∝exp⁡(−12((��−����−1)2��+(��−1−�ˉ�−1�0)21−�ˉ�−1−(��−�ˉ��0)21−�ˉ�))=exp⁡(−12(��2−2������−1+����−12��+��−12−2�ˉ�−1�0��−1+�ˉ�−1�021−�ˉ�−1−(��−�ˉ��0)21−�ˉ�))=exp⁡(−12((����+11−�ˉ�−1)��−12⏟��−1 方差 −(2������+2�ˉ�−11−�ˉ�−1�0)��−1⏟��−1 均值 +�(��,�0)⏟与 ��−1 无关 ))q(xt−1​∣xt​,x0​)​=q(xt​∣xt−1​,x0​)q(xt​∣x0​)q(xt−1​∣x0​)​∝exp(−21​(βt​(xt​−αt​​xt−1​)2​+1−αˉt−1​(xt−1​−αˉt−1​​x0​)2​−1−αˉt​(xt​−αˉt​​x0​)2​))=exp(−21​(βt​xt2​−2αt​​xt​xt−1​+αt​xt−12​​+1−αˉt−1​xt−12​−2αˉt−1​​x0​xt−1​+αˉt−1​x02​​−1−αˉt​(xt​−αˉt​​x0​)2​))=exp(−21​(xt−1​ 方差 (βt​αt​​+1−αˉt−1​1​)xt−12​​​−xt−1​ 均值 (βt​2αt​​​xt​+1−αˉt−1​2αˉt−1​​​x0​)xt−1​​​+与 xt−1​ 无关 C(xt​,x0​)​​))​

上面推导过程中, 通过贝叶斯公式巧妙的将逆向过程转换为前向过程, 且最终得到的概率密度函数和高斯概率密度函数的指数部分 exp⁡(−(�−�)22�2)=exp⁡(−12(1�2�2−2��2�+�2�2))exp(−2σ2(x−μ)2​)=exp(−21​(σ21​x2−σ22μ​x+σ2μ2​)) 能对应, 即有:

扩散模型 (Diffusion Model) 简要介绍与源码分析_扩散模型_05

通过公式 (8) 和公式 (9), 我们能得到 �(��−1∣��,�0)q(xt−1​∣xt​,x0​) 的分布. 此外由于公式 (3) 揭示的 ��xt​ 和 �0x0​ 之间的关系: ��=�ˉ��0+1−�ˉ��ˉ�xt​=αˉt​​x0​+1−αˉt​​ϵˉt​, 可以得到

�0=1�ˉ�(��−1−�ˉ���)x0​=αˉt​​1​(xt​−1−αˉt​​ϵt​)​

代入公式 (9) 中得到:

扩散模型 (Diffusion Model) 简要介绍与源码分析_深度学习_06

补充一下公式 (11) 的详细推导过程:

扩散模型 (Diffusion Model) 简要介绍与源码分析_计算机视觉_07

前面说到, 我们将使用深度学习模型 ��pθ​ 去拟合逆向过程的分布 �(��−1∣��)q(xt−1​∣xt​), 由上面公式知 ��(��−1∣��)=�(��−1;��(��,�),Σ�(��,�))pθ​(xt−1​∣xt​)=N(xt−1​;μθ​(xt​,t),Σθ​(xt​,t)), 我们希望训练模型 ��(��,�)μθ​(xt​,t) 以预估 �~�=1��(��−1−��1−�ˉ���)μ~​t​=αt​​1​(xt​−1−αˉt​​1−αt​​ϵt​). 由于 ��xt​ 在训练阶段会作为输入, 因此它是已知的, 我们可以转而让模型去预估噪声 ��ϵt​, 即令:

��(��,�)=1��(��−1−��1−�ˉ���(��,�))Thus ��−1=�(��−1;1��(��−1−��1−�ˉ���(��,�)),��(��,�))μθ​(xt​,t)Thus xt−1​​=αt​​1​(xt​−1−αˉt​​1−αt​​ϵθ​(xt​,t))=N(xt−1​;αt​​1​(xt​−1−αˉt​​1−αt​​ϵθ​(xt​,t)),Σθ​(xt​,t))​

模型训练

前面谈到, 逆向阶段让模型去预估噪声 ��(��,�)ϵθ​(xt​,t), 那么应该如何设计 Loss 函数 ? 我们的目标是在真实数据分布下, 最大化模型预测分布的对数似然, 即优化在 �0∼�(�0)x0​∼q(x0​) 下的 ��(�0)pθ​(x0​) 交叉熵:

�=��(�0)[−log⁡��(�0)]L=Eq(x0​)​[−logpθ​(x0​)]​

和  变分自动编码器 VAE 类似, 使用 Variational Lower Bound 来优化: −log⁡��(�0)−logpθ​(x0​) :

扩散模型 (Diffusion Model) 简要介绍与源码分析_DDPM_08

对公式 (15) 左右两边取期望 ��(�0)Eq(x0​)​, 利用到重积分中的  Fubini 定理 可得:

����=��(�0)(��(�1:�∣�0)[log⁡�(�1:�∣�0)��(�0:�)])=��(�0:�)[log⁡�(�1:�∣�0)��(�0:�)]⏟Fubini定理 ≥��(�0)[−log⁡��(�0)]LVLB​=Fubini定理 Eq(x0​)​(Eq(x1:T​∣x0​)​[logpθ​(x0:T​)q(x1:T​∣x0​)​])=Eq(x0:T​)​[logpθ​(x0:T​)q(x1:T​∣x0​)​]​​≥Eq(x0​)​[−logpθ​(x0​)]

因此最小化 ����LVLB​ 就可以优化目标函数 �L. 之后对 ����LVLB​ 做进一步的推导, 这部分的详细推导见上面的参考文章, 最终的结论是:

����=��+��−1+…+�0��=���(�(��∣�0)∣∣��(��))��=���(�(��∣��−1,�0)∣∣��(��∣��+1));1≤�≤�−1�0=−log⁡��(�0∣�1)LVLB​LT​Lt​L0​​=LT​+LT−1​+…+L0​=DKL​(q(xT​∣x0​)∣∣pθ​(xT​))=DKL​(q(xt​∣xt−1​,x0​)∣∣pθ​(xt​∣xt+1​));1≤t≤T−1=−logpθ​(x0​∣x1​)​

最终是优化两个高斯分布 �(��∣��−1,�0)=�(��−1;�~(��,�0),�~��)q(xt​∣xt−1​,x0​)=N(xt−1​;μ~​(xt​,x0​),β~​t​I) 与 ��(��∣��+1)=�(��−1;��(��,�),Σ�)pθ​(xt​∣xt+1​)=N(xt−1​;μθ​(xt​,t),Σθ​) (此为模型预估的分布)之间的 KL 散度. 由于多元高斯分布的 KL 散度存在闭式解, 详见:  Multivariate_normal_distributions, 从而可以得到:

��=��0,�[12∥��(��,�)∥22∥�~�(��,�0)−��(��,�)∥2]=��0,�[12∥��∥22∥1��(��−1−��1−�ˉ���)−1��(��−1−��1−�ˉ���(��,�))∥2]=��0,�[(1−��)22��(1−�ˉ�)∥��∥22∥��−��(��,�)∥2];其中��为高斯噪声,��为模型学习的噪声=��0,�[(1−��)22��(1−�ˉ�)∥��∥22∥��−��(�ˉ��0+1−�ˉ���,�)∥2]Lt​​=Ex0​,ϵ​[2∥Σθ​(xt​,t)∥22​1​∥μ~​t​(xt​,x0​)−μθ​(xt​,t)∥2]=Ex0​,ϵ​[2∥Σθ​∥22​1​∥αt​​1​(xt​−1−αˉt​​1−αt​​ϵt​)−αt​​1​(xt​−1−αˉt​​1−αt​​ϵθ​(xt​,t))∥2]=Ex0​,ϵ​[2αt​(1−αˉt​)∥Σθ​∥22​(1−αt​)2​∥ϵt​−ϵθ​(xt​,t)∥2];其中ϵt​为高斯噪声,ϵθ​为模型学习的噪声=Ex0​,ϵ​[2αt​(1−αˉt​)∥Σθ​∥22​(1−αt​)2​∥ϵt​−ϵθ​(αˉt​​x0​+1−αˉt​​ϵt​,t)∥2]​

DDPM 将 Loss 简化为如下形式:

��simple =��0,��[∥��−��(�ˉ��0+1−�ˉ���,�)∥2]Ltsimple ​=Ex0​,ϵt​​[∥∥​ϵt​−ϵθ​(αˉt​​x0​+1−αˉt​​ϵt​,t)∥∥​2]​

因此 Diffusion 模型的目标函数即是学习高斯噪声 ��ϵt​ 和 ��ϵθ​ (来自模型输出) 之间的 MSE loss.

最终算法

最终 DDPM 的算法流程如下:

扩散模型 (Diffusion Model) 简要介绍与源码分析_扩散模型_09

训练阶段重复如下步骤:

  • 从数据集中采样 �0x0​
  • 随机选取 time step �t
  • 生成高斯噪声 ��∈�(0,�)ϵt​∈N(0,I)
  • 调用模型预估 ��(�ˉ��0+1−�ˉ���,�)ϵθ​(αˉt​​x0​+1−αˉt​​ϵt​,t)
  • 计算噪声之间的 MSE Loss: ∥��−��(�ˉ��0+1−�ˉ���,�)∥2∥∥​ϵt​−ϵθ​(αˉt​​x0​+1−αˉt​​ϵt​,t)∥∥​2, 并利用反向传播算法训练模型.

逆向阶段采用如下步骤进行采样:

  • 从高斯分布采样 ��xT​
  • 按照 �,…,1T,…,1 的顺序进行迭代:
    • 如果 �=1t=1, 令 �=0z=0; 如果 �>1t>1, 从高斯分布中采样 �∼�(0,�)z∼N(0,I)
    • 利用公式 (12) 学习出均值 ��(��,�)=1��(��−1−��1−�ˉ���(��,�))μθ​(xt​,t)=αt​​1​(xt​−1−αˉt​​1−αt​​ϵθ​(xt​,t)), 并利用公式 (8) 计算均方差 ��=�~�=1−�ˉ�−11−�ˉ�⋅��σt​=β~​t​​=1−αˉt​1−αˉt−1​​⋅βt​​
    • 通过重参数技巧采样 ��−1=��(��,�)+���xt−1​=μθ​(xt​,t)+σt​z
  • 经过以上过程的迭代, 最终恢复 �0x0​.

源码分析

DDPM 文章以及代码的相关信息如下:

本文以分析 Tensorflow 源码为主, Pytorch 版本的代码和 Tensorflow 版本的实现逻辑大体不差的, 变量名字啥的都类似, 阅读起来不会有啥门槛. Tensorlow 源码对 Diffusion 模型的实现位于  diffusion_utils_2.py, 模型本身的分析以该文件为主.

训练阶段

以 CIFAR 数据集为例.

在  run_cifar.py 中进行前向传播计算 Loss:

扩散模型 (Diffusion Model) 简要介绍与源码分析_深度学习_10

  • 第 6 行随机选出 �∼Uniform({1,…,�})t∼Uniform({1,…,T})
  • 第 7 行 training_losses 定义在  GaussianDiffusion2 中, 计算噪声间的 MSE Loss.

进入  GaussianDiffusion2 中, 看到初始化函数中定义了诸多变量, 我在注释中使用公式的方式进行了说明:

扩散模型 (Diffusion Model) 简要介绍与源码分析_DDPM_11

下面进入到 training_losses 函数中:

扩散模型 (Diffusion Model) 简要介绍与源码分析_扩散模型_12

  • 第 19 行: self.model_mean_type 默认是 eps, 模型学习的是噪声, 因此 target 是第 6 行定义的 noise, 即 ��ϵt​
  • 第 9 行: 调用 self.q_sample 计算 ��xt​, 即公式 (3) ��=�ˉ��0+1−�ˉ���xt​=αˉt​​x0​+1−αˉt​​ϵt​
  • 第 21 行: denoise_fn 是定义在  unet.py 中的 UNet 模型, 只需知道它的输入和输出大小相同; 结合第 9 行得到的 ��xt​, 得到模型预估的噪声: ��(�ˉ��0+1−�ˉ���,�)ϵθ​(αˉt​​x0​+1−αˉt​​ϵt​,t)
  • 第 23 行: 计算两个噪声之间的 MSE: ∥��−��(�ˉ��0+1−�ˉ���,�)∥2∥∥​ϵt​−ϵθ​(αˉt​​x0​+1−αˉt​​ϵt​,t)∥∥​2, 并利用反向传播算法训练模型

上面第 9 行定义的 self.q_sample 详情如下:

扩散模型 (Diffusion Model) 简要介绍与源码分析_深度学习_13

  • 第 13 行的 q_sample 已经介绍过, 不多说.
  • 第 2 行的 _extract 在代码中经常被使用到, 看到它只需知道它是用来提取系数的即可. 引入输入是一个 Batch, 里面的每个样本都会随机采样一个 time step �t, 因此需要使用 tf.gather 来将 ��ˉαt​ˉ​ 之类选出来, 然后将系数 reshape 为 [B, 1, 1, ....] 的形式, 目的是为了利用 broadcasting 机制和 ��xt​ 这个 Tensor 相乘.

前向的训练阶段代码实现非常简单, 下面看逆向阶段

逆向阶段

逆向阶段代码定义在  GaussianDiffusion2 中:

扩散模型 (Diffusion Model) 简要介绍与源码分析_计算机视觉_14

  • 第 5 行生成高斯噪声 ��xT​, 然后对其不断去噪直至恢复原始图像
  • 第 11 行的 self.p_sample 就是公式 (6) ��(��−1∣��)=�(��−1;��(��,�),Σ�(��,�))pθ​(xt−1​∣xt​)=N(xt−1​;μθ​(xt​,t),Σθ​(xt​,t)) 的过程, 使用模型来预估 ��(��,�)μθ​(xt​,t) 以及 Σ�(��,�)Σθ​(xt​,t)
  • 第 12 行的 denoise_fn 在前面说过, 是定义在  unet.py 中的 UNet 模型; img_ 表示 ��xt​.
  • 第 13 行的 noise_fn 则默认是 tf.random_normal, 用于生成高斯噪声.

进入 p_sample 函数:

扩散模型 (Diffusion Model) 简要介绍与源码分析_扩散模型_15

  • 第 7 行调用 self.p_mean_variance 生成 ��(��,�)μθ​(xt​,t) 以及 log⁡(Σ�(��,�))log(Σθ​(xt​,t)), 其中 Σ�(��,�)Σθ​(xt​,t) 通过计算 �~�β~​t​ 得到.
  • 第 11 行从高斯分布中采样 �z
  • 第 18 行通过重参数技巧采样 ��−1=��(��,�)+���xt−1​=μθ​(xt​,t)+σt​z, 其中 ��=�~�σt​=β~​t​​

进入 self.p_mean_variance 函数:

扩散模型 (Diffusion Model) 简要介绍与源码分析_计算机视觉_16

  • 第 6 行调用模型 denoise_fn, 通过输入 ��xt​, 输出得到噪声 ��ϵt​
  • 第 19 行 self.model_var_type 默认为 fixedlarge, 但我当时看 fixedsmall 比较爽, 因此 model_variance 和 model_log_variance 分别为 �~�=1−�ˉ�−11−�ˉ�⋅��β~​t​=1−αˉt​1−αˉt−1​​⋅βt​ (见公式 8), 以及 log⁡�~�logβ~​t​
  • 第 29 行调用 self._predict_xstart_from_eps 函数, 利用公式 (10) 得到 �0=1�ˉ�(��−1−�ˉ���)x0​=αˉt​​1​(xt​−1−αˉt​​ϵt​)
  • 第 30 行调用 self.q_posterior_mean_variance 通过公式 (9) 得到 ��(��,�0)=��(1−�ˉ�−1)1−�ˉ���+�ˉ�−1��1−�ˉ��0μθ​(xt​,x0​)=1−αˉt​αt​​(1−αˉt−1​)​xt​+1−αˉt​αˉt−1​​βt​​x0​

self._predict_xstart_from_eps 函数详情如下:

扩散模型 (Diffusion Model) 简要介绍与源码分析_扩散模型_17

  • 该函数计算 �0=1�ˉ�(��−1−�ˉ���)x0​=αˉt​​1​(xt​−1−αˉt​​ϵt​)

self.q_posterior_mean_variance 函数详情如下:

扩散模型 (Diffusion Model) 简要介绍与源码分析_Stable-Diffusion_18

  • 相关说明见注释, 另外发现对于 ��(��,�0)μθ​(xt​,x0​) 的计算使用的是公式 (9) ��(��,�0)=��(1−�ˉ�−1)1−�ˉ���+�ˉ�−1��1−�ˉ��0μθ​(xt​,x0​)=1−αˉt​αt​​(1−αˉt−1​)​xt​+1−αˉt​αˉt−1​​βt​​x0​ 而不是进一步推导后的公式 (11) ��(��,�0)=1��(��−1−��1−�ˉ���)μθ​(xt​,x0​)=αt​​1​(xt​−1−αˉt​​1−αt​​ϵt​).

总结

本文分析了扩散模型 DDPM 算法,对原理以及代码进行了剖析,公式比较多,手推一遍再结合代码分析会有更深的体会。

  • 15
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI周红伟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值