最通俗易懂的扩散模型

当前主要有四大生成模型:生成对抗模型,变分自动编码器、流模型和扩散模型。扩散模型(diffusion models)是当前深度生成模型中新SOTA。扩散模型在图片生成任务中超越了原SOTA:GAN,并且在诸多应用领域都有出色的表现,如计算机视觉,NLP、波形信号处理、多模态建模、分子图建模、时间序列建模、对抗性净化等。此外,扩散模型与其他研究领域有着密切的联系,如稳健学习、表示学习、强化学习。
在这里插入图片描述

Variational AutoEncoder (VAE)

要讲扩散模型,从 VAE 和 GAN 说起,它们都是从隐变量 Z Z Z 生成目标数据 X X X。它们假设隐变量服从某种常见的概率分布(比如正态分布),然后希望训练一个模型 X = g ( Z ) X = g(Z) X=g(Z),这个模型将原来的概率分布映射到训练集的概率分布,也就是分布的变换。注意,VAE 和 GAN 的本质都是概率分布的映射。大致思路如下:
图片来源:https://zhuanlan.zhihu.com/p/34998569
换句话说,就是先在某种分布上随机生成一组隐变量,然后这个隐变量会经过一个生成器生成一组目标数据。VAE 和 GAN 都希望这组数据的分布 X ^ \hat{X} X^ 和目标分布 X X X 尽可能接近。

是不是听上去很 work?但是这种方法本质上是难以 work 的,因为“尽量接近”并没有一个确定的关于 X X X X ^ \hat{X} X^ 的相似度的评判标准。换句话说,这种方法的难度在于必须去猜测“它们的分布相等吗”,而缺少真正可解释的价值判断。那 KL 散度不行吗?当然不行,KL 散度是针对两个已知的概率分布求相似度的,而 X ^ \hat{X} X^ X X X 的概率分布目前都是未知的。

GAN 的粗暴做法就是直接学习这个度量标准,但是依然缺乏可解释性,非常不优雅。VAE 的做法要优雅很多,我们先来理解 VAE 是如何做的,理解了 VAE 以后再去理解 Diffussion 就很自然了。

到底什么是生成模型?

那生成模型到底是个啥?假设有一批样本 X X X,想要从 X X X 学习到它的分布 p ( X ) p(X) p(X),这样就能同时学习到没被采样到的数据了,用这个分布 p ( X ) p(X) p(X) 就能随意采样,然后获得生成结果。但是这个分布九转回肠,根本不可能直接获得。所以绕个弯,整一个隐变量 Z Z Z,这东西可以生成 X X X。不妨假设 Z Z Z 满足正态分布,那就可以先从正态分布里面随机取一个 Z Z Z,然后用 Z Z Z X X X 的关系算出 p ( X ) p(X) p(X)。这里需要用到数学公式:
p ( X ) = ∑ Z p ( X ∣ Z ) p ( Z ) p(X) = \sum_Z p(X|Z)p(Z) p(X)=Zp(XZ)p(Z)
换句话说,就是不直接求 p ( X ) p(X) p(X),而是造一个别的变量(隐变量),获得隐变量和 X X X 的关系,也就能得到 p ( X ) p(X) p(X)上式中的 p ( X ∣ Z ) p(X|Z) p(XZ) 称为后验分布,p(Z) 为先验分布。

VAE

VAE 的核心就是不仅假设 p ( Z ) p(Z) p(Z) 是正态分布,而且假设每个 p ( X k ∣ Z ) p(X_k|Z) p(XkZ) 也是正态分布。什么意思呢?因为 X X X 是一组采样,其实可以表示成 X = { X 1 , X 2 , … , X k } X = \{X_1, X_2, \dots, X_k \} X={X1,X2,,Xk},而我们想要针对每个 X k X_k Xk 获得一个专属于它和 Z Z Z 的一个正态分布。换句话说,有 k k k X X X 样本,就有 k k k 个正态分布 p ( X k ∣ Z ) p(X_k|Z) p(XkZ)。其实也很好理解,每一个采样点当然都需要一个相对 Z Z Z 的分布,因为没有任何两个采样点是完全一致的。

那就要想方设法获得这 k k k 个正态分布了,怎么搞?学!拟合!但是这里需要注意,这里的拟合与 GAN 不同,本质上是在学 X k X_k Xk Z Z Z 的关系,而非学习比较 X X X X ^ \hat{X} X^ 的标准。

OK,现在问一个小学二年级就知道的问题,已知是正态分布,学什么才能确定这个是正太分布?没错,均值和方差。怎么学?有数据呀! Z Z Z 不是你自己假设的吗, X k X_k Xk 是已知的,那就用这两去学习均值和方差。

好,现在我们已经学习到了这 k k k 个正态分布。那就好说了,直接从 p ( Z ∣ X k ) p(Z|X_k) p(ZXk) 里面采样一个 Z k Z_k Zk,学一个生成器,就能获得 X k = g ( Z k ) X_k = g(Z_k) Xk=g(Zk) 了。那接下来只需要最小化方差 D 2 ( X ^ k , X k ) D^2(\hat{X}_k, X_k) D2(X^k,Xk) 就行。如下图所示
在这里插入图片描述

仔细理解的时候有没有发现一个问题?为什么在文章开头,强调了没法直接比较 X X X X ^ \hat{X} X^ 的分布,而在这里,我们认为可以直接比较?注意,这里的 Z k Z_k Zk 是专属于(针对于) X k X_k Xk 的隐变量,那么和 X ^ k \hat{X}_k X^k 本身就有对比关系,因此右边的蓝色方块的 “生成器”,是一一对应的生成。

另外,大家可以看到,均值和方差本质上都是 encoder,也就是说,VAE 其实利用了两个 encoder 去分别学习均值和方差。

VAE 的 VAriational 到底是个啥

这里还有一个非常重要的问题(对于初学者而言可能比较困难,需要反复思考):由于我们通过最小化 D 2 ( X k ∣ X ^ k ) D^2(X_k|\hat{X}_k) D2(XkX^k) 来训练右边的生成器,最终模型会逐渐使得 X k X_k Xk X ^ k \hat{X}_k X^k 趋于一致。但是注意,因为 Z k Z_k Zk 是重新采样过的,而不是通过均值和方差 encoder 学出来的,这个生成器的输入 Z Z Z 是有噪声的。但是仔细思考一下,这个噪声的大小其实就用方差来度量。为了使分布的学习尽量接近,我们希望噪声越小越好,所以我们会尽量使得方差趋近于 0。

但是方差不能为 0,因为我们好像给模型一些训练难度。如果方差为 0,模型永远只需要学习高斯分布的均值,这样就丢失了随机性,VAE 就变成了 AE。这就是为什么 VAE 要在 AE 前面加一个 Variational:我们希望方差能够持续存在,从而带来噪声!那如何解决这个问题呢?其实保证有方差就行,但是 VAE 给出了一个优雅的答案:不仅需要保证有方差,还要让所有 p ( Z ∣ X ) p(Z|X) p(ZX) 趋于正态分布 N ( 0 , I ) N(0, I) N(0,I)!为什么要这么做呢?这里需要一个小小的数学推导:
P ( Z ) = ∑ Z p ( Z ∣ X ) p ( X ) = N ( 0 , I ) p ( x ) = N ( 0 , I ) ∑ X p ( X ) = N ( 0 , I ) P(Z) = \sum_Z p(Z|X)p(X) = N(0, I)p(x) = N(0, I)\sum_X p(X) = N(0, I) P(Z)=Zp(ZX)p(X)=N(0,I)p(x)=N(0,I)Xp(X)=N(0,I)
这条式子想必大家都看得懂,看不懂也没关系……关键是结论:如果所有 p ( Z ∣ X ) p(Z|X) p(ZX) 都趋近于 N ( 0 , I ) N(0, I) N(0,I),那么我们可以保证 p ( Z ) p(Z) p(Z) 也趋近于 N ( 0 , I ) N(0, I) N(0,I),从而实现先验的假设,这样就形成了一个闭环!太优雅了,那怎么保证让所有 p ( Z ∣ X ) p(Z|X) p(ZX) 趋近于 N ( 0 , I ) N(0,I) N(0,I) 呢?加 Loss 呗,具体的 Loss 推导就不深入了,用到了很多数学知识。到此位置,我们把 VAE 进一步画成:
在这里插入图片描述

VAE 的本质

现在再回顾 VAE 到底做了啥。VAE 就是在 AE 的基础上对均值的 encoder 添加高斯噪声(正态分布的随机采样),使得 decoder(生成器)有噪声鲁棒性;为了防止噪声消失,将所以 p ( Z ∣ X ) p(Z|X) p(ZX) 趋近于标准正太分布,将 encoder 的均值尽量降为 0,而将方差尽量保持住。这样一来,当 decoder 训练不好的时候,整个体系就可以降低噪声;当 decoder 逐渐拟合的时候,就会增加噪声。

Diffusion Model(扩散模型)

到此为止,你已经理解了扩散模型的所有基础。现在我们就站在 VAE 的基础上讲讲扩散模型。目前的教程实在台数学了,其实可以用更加通俗的语言讲清楚。从本质上说,Diffusion 就是 VAE 的升级版。

VAE 有一个做了好几年的核心问题。大家思考一下,上面的 VAE 中,变分后验 P ( X ∣ Z ) P(X|Z) P(XZ) 是怎么获得的?是学出来的!用 D 2 ( X k ∣ X ^ k ) D^2(X_k|\hat{X}_k) D2(XkX^k) 当 Loss,去学这个 p ( X ∣ Z ) p(X|Z) p(XZ)。学这个变分后验就有五花八门的方法了,除了上面说的拟合法,还有用纯数学来做的,甚至有用 BERT 这种预训练语言模型(PLM)来做的。但是无论如何都跳不出 VAE 这个框架:必须想办法设计一个生成器 g ( Z ) = X g(Z) = X g(Z)=X,使得变分后验分布 p ( X ∣ Z ) p(X|Z) p(XZ) 尽量真实。这种方法的问题在于,这个变分后验 p ( X ∣ Z ) p(X|Z) p(XZ) 的表达能力与计算代价不可兼得。换句话说,简单的变分后验表达并不丰富(例如数学公式法),而复杂的变分后验计算过程(例如 PLM 法)。

现在回头看看 GAN 做了啥。前面提到过,GAN 非常简单粗暴,没有任何 encoder,直接训练生成器,唯一的难度在于判别器(就是下图这个“它们的分布相等吗”的东西)不好做。
在这里插入图片描述
好了,聪明的你已经知道我要说什么了。Diffusion 本质就是借鉴了 GAN 这种训练目标单一的思路和 VAE 这种不需要判别器的隐变量变分的思路,糅合一下,发现还真 work 了……下面让我们看看到底是怎么糅合的。为什么我们糅合甚至还没传统方法好,大佬糅合出个 diffusion?

Diffusion 的核心

前面提到过,VAE 的最大问题就是这个变分后验。在 VAE 中,我们先定义了右边蓝色的生成器 X = g ( Z ) X = g(Z) X=g(Z),再学一个变分后验 p ( X ∣ Z ) p(X|Z) p(XZ) 来适配这个生成器。能不能反一下,先定义一个变分后验再学一个生成器呢?
在这里插入图片描述

如果你仔细看了上面 VAE 的部分,相信你已经有思路了。VAE 的生成器,是将标准高斯分布映射到数据样本(自己定义的)。VAE 的后验分布,是将数据样本映射到标准高斯分布(学出来的)。那反过来,我想要设计一种方法 A,使得 A 用一种简单的“变分后验”将数据样本映射到标准高斯分布(自己定义的),并且使得 A 的生成器,将标准高斯分布映射到数据样本(学出来的)。注意,因为生成器的搜索空间大于变分后验,VAE 的效率远不及 A 方法:因为 A 方法是学一个生成器(搜索空间大),所以可以直接模仿这个“变分后验”的每一小步

好,现在我告诉你,这个 A 方法就是扩散模型的核心思路:定义一个类似于“变分后验”的数据样本到高斯分布的映射,然后学一个生成器,这个生成器模仿我们定义的这个映射的每一小步

Diffusion Model 的 Diffusion 到底是个啥?

接触 diffusion 的你肯定知道马尔科夫链!举个例子说明马尔科夫链。马尔科夫链可以用很简单的话来解释:它就像是一种随机游走,但是有规则可循。想象你在一个由许多房间组成的大厦里,每个房间都有一扇门通向其他房间。你从当前房间出发,每次随机选择一扇门进入下一个房间。在每个房间里,你都会做出一个随机的选择,决定下一个去哪个房间。这就好像是在进行一次“随机游走”。
而马尔科夫链的特殊之处在于,你在选择下一个房间时,并不是完全随机的,而是受到当前所在房间的影响。也就是说,你下一个去哪个房间的概率取决于你当前所在的房间,而与你之前所在的房间无关。这就好像是在进行一种“有规则的随机游走”。简单来说,马尔科夫链就是一种具有特定规则的随机过程,其中未来状态的概率只依赖于当前状态,而与过去状态无关。这种特性使得马尔科夫链在建模许多随机过程时非常有用,比如在模拟随机漫步、网络传播模型等方面有广泛的应用。

这东西不仅 diffusion 有,各种怪异的算法里面也出现了。为什么用它?因为它的一个关键性质:平稳性。一个概率分布如果随时间变化,那么在马尔科夫链的作用下,它一定会趋于某种平稳分布(例如高斯分布)。只要终止时间足够长,概率分布就会趋近于这个平稳分布。

这个逐渐逼近的过程被称为前向过程(forward process)。**注意,这个过程的本质还是加噪声!**试想一下为什么……其实和 VAE 非常相似,都是随机采样!马尔科夫链每一步的转移概率,本质上都是在加噪声。这就是扩散模型中“扩散”的由来:噪声在马尔科夫链演化的过程中,逐渐进入 diffusion 体系。随着时间的推移,加入的噪声(加入的溶质)越来越少,而体系中的噪声(这个时刻前的所有溶质)逐渐在 diffusion 体系中扩散,直至均匀。看下面的图,你应该就恍然大悟了:

在这里插入图片描述

现在想想,为什么要用马尔科夫链。我们把问题重述一下:为什么我们创造一个稳定分布为高斯分布的马尔科夫链,对于生成器模仿我们定义的某个映射的每一小步有帮助呢?这里你肯定想不出来,不然你能发明 diffusion model——答案是。基于马尔科夫链的前向过程,其每一个 epoch 的逆过程都可以近似为高斯分布。

懵了吧,我也懵了。真正的推导发了好几篇 paper,都是些数学巨佬的工作,不得不感叹基础科学的力量……相关工作主要用的是 SDE(随机微分方程),我们这里不深入,但是需要理解大致的思路,如下图所示:
在这里插入图片描述
下面是前向过程,上面是反向过程。前向过程通过马尔科夫链的相转移概率不断加入噪声,从右边的采样数据到左边的标准高斯分布;反向过程通过 SDE 来“抄袭”对应正向过程的那一个 epoch 的行为(其实每一步都不过是高斯分布),从而逐渐学习到对抗噪声的能力。高斯分布是一种很简单的分布,运算量小,这一点是 diffusion 快的重要原因。

Diffusion 的本质

现在回头看看 Diffusion 到底做了个啥工作,我们着重看 VAE 和 Diffusion 的区别:
在这里插入图片描述
可以很清晰的认识到,VAE 的本质是一个基于梯度的 encoder-decoder 架构,encoder 用来学习高斯分布的均值和方差,decoder 用变分后验来学习生成能力,而将标准高斯分布映射到数据样本(自己定义的)。而扩散模型的本质是一个 SDE/Markov 架构,虽然也借鉴了神经网络的前向/反向传播概念,但是并不基于可微的梯度,属于数学层面上的创新。两者都定义了高斯分布 Z Z Z 作为隐变量,但是 VAE 将 Z Z Z 作为先验条件(变分先验),而 Diffusion 将 Z Z Z 作为类似于变分后验的马尔科夫链的平稳分布。

代码

下面我们通过代码在介绍一下 Diffusion。它包括两个过程:前向过程和反向过程。其中前向过程又称为扩散过程。无论是前向过程还是反向过程都是一个参数化的马尔科夫链,其中反向过程可用于生成数据样本(它的作用类似 GAN 的生成器,只不过 GAN 生成器会有维度变化,而 DDPM 的反向过程没有维度变化)。
在这里插入图片描述

  • x 0 x_0 x0 x T x_T xT 为逐步加噪的前向过程,噪声是从高斯分布中采样的(已知的),该过程从原始图片逐步加噪至成为纯噪声。
  • X T X_T XT x 0 x_0 x0 为将随机噪声还原为输入的过程。该过程需要学习一个去噪过程,直至还原一张图片。
前向过程

前向过程是加噪的过程,前向过程中图像 x t x_t xt 只和上一时刻 x t − 1 x_{t-1} xt1 有关,表示为:
q ( x 1 : T ∣ x 0 ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) q ( x t ∣ x t − 1 ) = N ( x t , 1 − β t x t − 1 , β t I ) q(x_{1:T}|x_0) = \prod_{t=1}^T q(x_t|x_{t-1}) \\ q(x_t|x_{t-1}) = N(x_t, \sqrt{1-\beta_t} x_{t-1}, \beta_t I) q(x1:Tx0)=t=1Tq(xtxt1)q(xtxt1)=N(xt,1βt xt1,βtI)
其中不同的 t t t 是预先定义好的逐渐衰减的。可以是 Linear,cosine 等,满足 β 1 < β 2 < ⋯ < β T \beta_1 < \beta_2 < \dots < \beta_T β1<β2<<βT β t \beta_t βt 的生成代码如下:

def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return np.linspace(beta_start, beta_end, timesteps).astype(np.float32)


def cosine_beta_schedule(time_steps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = time_steps + 1
    x = np.linspace(0, time_steps, steps).astype(np.float32)
    alphas_cumprod = np.cos(((x / time_steps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, 0, 0.999)

根据上式,可以通过重参数化采样得到 x t x_t xt。令 ϵ   N ( 0 , I ) \epsilon ~ N(0, I) ϵ N(0,I) α t = 1 − β t \alpha_t = 1 - \beta_t αt=1βt α t ˉ = ∏ i = 1 T α i \bar{\alpha_t} = \prod_{i=1}^T \alpha_i αtˉ=i=1Tαi,经过推导,可以得出 x t x_t xt x 0 x_0 x0 的关系:
q ( x t ∣ x 0 ) = N ( x t ; α t ˉ x 0 , ( 1 − α ˉ ) I ) q(x_t|x_0) = N(x_t; \sqrt{\bar{\alpha_t}}x_0, (1 - \bar{\alpha})I) q(xtx0)=N(xt;αtˉ x0,(1αˉ)I)

逆向过程

逆向过程是去噪的过程。如果得到逆向过程 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt),就可以通过随机噪声 x T x_T xT 逐步还原出一张图像。DDPM 使用神经网络 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt) 拟合逆向过程 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt)

根据 q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ∣ μ t ~ ( x t , x 0 ) , β t ~ I ) q(x_{t-1}|x_t, x_0) = N(x_{t-1}|\tilde{\mu_t}(x_t, x_0), \tilde{\beta_t} I) q(xt1xt,x0)=N(xt1μt~(xt,x0),βt~I) 可以推导出:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ∣ μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_{\theta}(x_{t-1}|x_t) = N(x_{t-1}| \mu_{\theta}(x_t, t), \Sigma_{\theta}(x_t, t)) pθ(xt1xt)=N(xt1μθ(xt,t),Σθ(xt,t))
DDPM 论文中不计算方差,通过神经网络拟合均值 μ θ \mu_{\theta} μθ,从而得到 x t − 1 x_{t-1} xt1
μ θ = 1 α t ( x t − 1 − α t 1 − α t ˉ ϵ θ ( x t , t ) ) \mu_{\theta} = \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha_t}}} \epsilon_{\theta(x_t, t)}) μθ=αt 1(xt1αtˉ 1αtϵθ(xt,t))
因为 t t t x t x_t xt 已知,只需要使用神经网络拟合 ϵ θ ( x t , t ) \epsilon_{\theta(x_t, t)} ϵθ(xt,t)

UNet 职责

无论在前向过程还是后向过程,UNet 的职责都是根据当前的样本和时间 t t t 删噪声,也就是 UNet 实现 ϵ θ ( x t , t ) \epsilon_{\theta(x_t, t)} ϵθ(xt,t) 的预测,整个训练过程其实就是在训练 UNet 的参数。

Gaussion Diffusion 职责

前向过程:从 1 到 T 的时间采样一个时间 t t t,生成一个随机噪声加到图片上,从 UNet 获取预测噪声,计算损失后更新 UNet 梯度。
后向过程:先从正态分布中随机采样和训练一样大小的纯噪声图片,从 T-1 到 0 逐步重复以下步骤:从 x t x_t xt 还原 x t − 1 x_{t-1} xt1

训练过程
在这里插入图片描述

算法1:训练过程

  • 从数据中抽一个样本;
  • 从 1~T 中司机抽取一个时间 t t t
  • x 0 x_0 x0 t t t 传给 GaussionDiffusion,GaussionDiffusion 采样一个随机噪声,加到 x 0 x_0 x0,形成 x t x_t xt,然后将 x t x_t xt t t t 放入 UNet,UNet 根据 t t t 生成正弦位置编码和 x t x_t xt 结合,UNet 预测加的这个噪声,并返回噪声,GaussionDiffusion 计算该噪声和随机噪声的损失;
  • 将 UNet 预测的噪声与之前 GaussionDiffusion 采样的随机噪声求 L2 损失,计算梯度,更新权重;
    重复以上步骤,直到网络 UNet 训练完成。

训练步骤中每个模块的交互图如下:
在这里插入图片描述
算法2:采样

  • 从标准正态分布中采样出 x T x_T xT;
  • T , T − 1 , … , 2 , 1 T, T-1, \dots, 2, 1 T,T1,,2,1 依次重复以下步骤:
    (1) 从标准正态分布中采样 z z z,为重参数化做准备;
    (2) 根据模型求出 ϵ θ \epsilon_{\theta} ϵθ ,计算样本 noise 的均值,结合 x t x_t xt - pred_noise 和 z z z,利用重参数化技术得到 x t − 1 x_{t-1} xt1
  • 循环结束,返回 x 0 x_0 x0
    在这里插入图片描述

结合代码(MindSpore 版本)讲解

代码主要分为以下几块:UNet,GaussionDiffusion、Trainer

UNet 相关模块

结构如下:
在这里插入图片描述

正弦位置编码

DDPM 的每步训练是随机采样一个时间,为了让网络知道当前处理的是一系列去噪过程中的哪一个 step,我们需要将时间 t t t 的编码并传入网络之中,DDPM 使用的 UNet 是 time-condition UNet。类似于 Transformer的positional embedding,DDPM 采用正弦位置编码,既需要位置编码有界又需要两个时间步长之间的距离与句子长度无关。为了满足这两点标准,一种思路是使用有界的周期性函数,而简单的有界周期性函数很容易想到 sin 和 cos 函数。

class SinusoidalPosEmb(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = np.exp(np.arange(half_dim) * - emb)
        self.emb = Tensor(emb, mindspore.float32)
        self.Concat = _get_cache_prim(ops.Concat)(-1)

    def construct(self, x):
        emb = x[:, None] * self.emb[None, :]
        emb = self.Concat((ops.sin(emb), ops.cos(emb)))
        return emb

Attention 模块

DDPM 的 UNet 有 ResidualBlock 和 Attention Module

class Attention(nn.Cell):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.to_qkv = _get_cache_prim(Conv2d)(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
        self.to_out = _get_cache_prim(Conv2d)(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)
        self.map = ops.Map()
        self.partial = ops.Partial()
        self.bmm = BMM()
        self.split = ops.Split(axis=1, output_num=3)
        self.softmax = ops.Softmax(-1)

    def construct(self, x):
        b, c, h, w = x.shape
        qkv = self.split(self.to_qkv(x))
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)
        q = q * self.scale
        sim = self.bmm(q.swapaxes(2, 3), k)
        attn = self.softmax(sim)
        out = self.bmm(attn, v.swapaxes(2, 3))
        out = out.swapaxes(-1, -2).reshape((b, -1, h, w))
        return self.to_out(out)


class Residual(nn.Cell):
    """残差块"""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def construct(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

GaussionDiffusion 模块

首先定义相关的概率值,与公式对应:

        self.betas = betas
        self.alphas_cumprod = alphas_cumprod
        self.alphas_cumprod_prev = alphas_cumprod_prev

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
        self.sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))
        self.log_one_minus_alphas_cumprod = Tensor(np.log(1. - alphas_cumprod))
        self.sqrt_recip_alphas_cumprod = Tensor(np.sqrt(1. / alphas_cumprod))
        self.sqrt_recipm1_alphas_cumprod = Tensor(np.sqrt(1. / alphas_cumprod - 1))

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        self.posterior_variance = Tensor(posterior_variance)

        self.posterior_log_variance_clipped = Tensor(
            np.log(np.clip(posterior_variance, 1e-20, None)))
        self.posterior_mean_coef1 = Tensor(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.posterior_mean_coef2 = Tensor(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))

        p2_loss_weight = (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))\
                          ** - p2_loss_weight_gamma
        self.p2_loss_weight = Tensor(p2_loss_weight)

计算损失

基于 UNet 预测出 noise,使用预测的 noise 和真实的 noise 计算损失:

	def p_losses(self, x_start, t, noise, random_cond):
	    # 生成的真实noise
	    x = self.q_sample(x_start=x_start, t=t, noise=noise)
	
	    # if doing self-conditioning, 50% of the time, predict x_start from current set of times
	    if self.self_condition:
	        if random_cond:
	            _, x_self_cond = self.model_predictions(x, t)
	            x_self_cond = ops.stop_gradient(x_self_cond)
	        else:
	            x_self_cond = ops.zeros_like(x)
	    else:
	        x_self_cond = ops.zeros_like(x)
	
	    # model_out为基于Unet预测的pred_noise,此处self.model为Unet,ddpm默认预测目标是pred_noise。
	    model_out = self.model(x, t, x_self_cond)
	
	    if self.objective == 'pred_noise':
	        target = noise
	    elif self.objective == 'pred_x0':
	        target = x_start
	    elif self.objective == 'pred_v':
	        v = self.predict_v(x_start, t, noise)
	        target = v
	    else:
	        target = noise
		
		# 计算损失值
	    loss = self.loss_fn(model_out, target)
	    loss = loss.reshape(loss.shape[0], -1)
	    loss = loss * extract(self.p2_loss_weight, t, loss.shape)
	    return loss.mean()

采样

输出 x 0 x_0 x0,也就是原始图像,当 sampling_time_steps< time_steps,用下方函数:

def ddim_sample(self, shape, clip_denoise=True):
    batch = shape[0]
    total_timesteps, sampling_timesteps, = self.num_timesteps, self.sampling_timesteps
    eta, objective = self.ddim_sampling_eta, self.objective

    # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
    times = np.linspace(-1, total_timesteps - 1, sampling_timesteps + 1).astype(np.int32)
    # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
    times = list(reversed(times.tolist()))
    time_pairs = list(zip(times[:-1], times[1:]))

	# 采样第一次迭代,Unet输入img为随机采样
    img = np.random.randn(*shape).astype(np.float32)
    
    x_start = None

    for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):
        # time_cond = ops.fill(mindspore.int32, (batch,), time)
        time_cond = np.full((batch,), time).astype(np.int32)
        x_start = Tensor(x_start) if x_start is not None else x_start
        self_cond = x_start if self.self_condition else None
        predict_noise, x_start, *_ = self.model_predictions(Tensor(img, mindspore.float32),
                                                            Tensor(time_cond),
                                                            self_cond,
                                                            clip_denoise)
        predict_noise, x_start = predict_noise.asnumpy(), x_start.asnumpy()
        if time_next < 0:
            img = x_start
            continue

        alpha = self.alphas_cumprod[time]
        alpha_next = self.alphas_cumprod[time_next]

        sigma = eta * np.sqrt(((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)))
        c = np.sqrt(1 - alpha_next - sigma ** 2)

        noise = np.random.randn(*img.shape)

        img = x_start * np.sqrt(alpha_next) + c * predict_noise + sigma * noise

    img = self.unnormalize(img)

    return img

Trainer 训练器

data_iterator 中每次取出的数据集就是一个 batch_size 的大小,每训练一个 batch,step 加 1。

指数移动平均
DDPM 的 trainer 采用 ema(指数移动平均)优化,ema 不参与训练,只参与推理,比对变量直接赋值而言,移动平均得到的值在图像上更加平缓光滑,抖动性更小。具体代码参考代码仓中 ema.py。

参数:

  • num_timesteps:原理中提到的T,扩散的步数;
  • train_num_steps:训练的总步数,每个 step 取用一个 batch 的数据。
print('training start')
        with tqdm(initial=self.step, total=self.train_num_steps, disable=False) as pbar:
            total_loss = 0.
            for (img,) in data_iterator:
                model.set_train()
                # 随机采样time向量
                time_emb = Tensor(
                    np.random.randint(0, num_timesteps, (img.shape[0],)).astype(np.int32))
                noise = Tensor(np.random.randn(*img.shape), mindspore.float32)
                # 返回损失、计算梯度、更新梯度
                self_cond = random.random() < 0.5 if self.self_condition else False
                loss = train_step(img, time_emb, noise, self_cond)

                # 损失累加
                total_loss += float(loss.asnumpy())

                self.step += 1
                if self.step % gradient_accumulate_every == 0:
                    # ema和model的参数同步更新
                    self.ema.update()
                    pbar.set_description(f'loss: {total_loss:.4f}')
                    pbar.update(1)
                    total_loss = 0.

                accumulate_step = self.step // gradient_accumulate_every
                accumulate_remain_step = self.step % gradient_accumulate_every
                if self.step != 0 and accumulate_step % self.save_and_sample_every == 0\
                        and accumulate_remain_step == 0:

                    self.ema.set_train(False)
                    self.ema.synchronize()
                    batches = num_to_groups(self.num_samples, self.batch_size)
                    all_images_list = list(map(lambda n: self.ema.online_model.sample(batch_size=n),
                                               batches))
                    self.save_images(all_images_list, accumulate_step)
                    self.save(accumulate_step)
                    self.ema.desynchronize()

                if self.step >= gradient_accumulate_every * self.train_num_steps:
                    break

        print('training complete')


参考文献:
如何通俗理解扩散模型?
一文读懂扩散模型,DDPM原理+代码解读

代码链接:
https://openi.pcl.ac.cn/drizzlezyk/ddpm2
https://xihe.mindspore.cn/projects/drizzlezyk/DDPM
https://github.com/drizzlezyk/DDPM-MindSpore

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值