从扩散模型 到分数扩散 到SDE

本文主要是本人看论文和网上视频博客的总结和个人的理解思路,如有错误感谢指正

一、扩散模型

1.介绍

扩散模型属于基于概率分布的生成式模型,终极目的是拟合数据的分布q(x_0)。区别于以往的自编码器模型、GAN模型等,扩散模型不直接对q(x_0)建模。q(x_0)是实际的数据分布,通常非常复杂,过去方法用网络来表示这一分布,即用p_\theta(x)拟合q(x_0)。扩散模型试图通过对N步去噪过程建模来避免直接拟合q(x_0)。(注意:我们用p表示拟合的或者人为设置的分布,用q表示实际的分布)

2.建模过程

        扩散模型假设N步去噪过程是一个马尔可夫过程,即每一步去噪仅依赖于上一步去噪的输出x_{t+1}来得到x_t。因此模型产生输出的过程是从分布q(x_T)采样一个样本x_T,再从给定x_T情况下的条件分布q(x_{T-1}\mid x_T)中采样出x_{T-1},再从给定x_{T-1}情况下的条件分布q(x_{T-2}\mid x_{T-1})中采样出x_{T-2},...,一直重复上面过程直到采样出x_0即为模型的输出。

        用更加形式化的方式表示上面过程。分布q(x_0)可以表示为q(x_0):=\int q(x_{0:T})dx_{1:T},其中q(x_{0:T})表示x_0,x_1,...,x_T的联合分布,x_1,...,x_T可以看作是一系列隐变量。从联合分布中把隐变量积分掉就得到余下那个变量的分布。当然我们不会真的去求一个分布的积分,主要是是我们不知道这个分布。所以实际做法就是上面提到的采样。

        根据条件概率公式P(A,B)=P(A\mid B)P(B)以及上面提到的马尔可夫假设q(x_t\mid x_{t+1},x_{t+2},..,x_{T}) = q(x_t \mid x_{t+1}),联合分布q(x_{0:T})可以表示为q(x_T)q(x_{T-1}\mid x_{T})q(x_{T-2}\mid x_{T-1})...q(x_0 \mid x_1)=q(x_T)\Pi_{t=1}^{T} q(x_{t-1}\mid x_t)。我们现在的目标是如何求出分布q(x_T)以及分布q(x_{t-1}\mid x_t)

        记p(x_T)是人为选定的对分布q(x_T)的近似,p_\theta(x_{t-1}\mid x_t)是用模型拟合的对分布q(x_{t-1}\mid x_t)的近似。于是对联合分布q(x_{0:T})的近似表示为p_\theta(x_{0:T})=p(x_T)\Pi_{t=1}^{T}p_\theta(x_{t-1}\mid x_t)

        很容易理解的一点是,如何去噪取决于加的噪声,也就是我们加上了什么噪声那么我们就要减去一个同样的噪声(看起来像废话,但后面我们会看到,DDPM实际上就是做的就是预测加的噪声)。在DDPM中,加噪声的操作是人为设计的,而且加的噪声的分布是人为选定的。DDPM加噪声的这一步操作采用高斯分布,也就是q(x_t \mid x_{t-1})服从高斯分布N(x_t;\sqrt{1-\beta_t}x_{t-1},\beta_t I)。为什么要选择高斯噪声呢?个人认为主要原因是高斯噪声足够简单,高斯噪声只包含均值和方差两个参数,涉及到高斯分布的运算包括加减乘除和KL散度等都比较简单。因为加噪过程是高斯分布,原论文假设在每一步变化比较小的情况下q(x_t \mid x_{t+1})也符合高斯分布,所以代码实现中\beta序列的变化必需足够缓慢,这导致加噪过程通常都有成百上千步,同样逆过程也需要成百上千步,后续有论文如DDIM优化采样的过程来减少步数。对于t=T,q(x_T \mid x_{T-1}) \sim N(x_T \mid \sqrt{1-\beta_T}x_{T-1},\beta_T I),我们可以看到\beta_T = 1时,分布q(x_T \mid x_{T-1}) \sim N(0, I)x_{T-1}无关,所以我们认为此时q(x_T) = q(x_T \mid x_{T-1}) \sim N(0,I),所以上面的p(x_T)我们可以简单地取标准高斯分布。其中\beta_1,...,\beta_T是事先确定的逐渐增大的参数序列(论文中称作variance shedule 或者beta shedule)。此时联合分布q(x_{0:T})可以表示为q(x_{1:T} \mid x_0) q(x_0),q(x_{1:T} \mid x_0)=q(x_T\mid x_{T-1})q(x_{T-1}\mid x_{T-2})...q(x_1\mid x_0) = \Pi_{t=1}^{T}q(x_t \mid x_{t-1})。其中除了q(x_0)未知外,其他都是已知的。我们知道q(x_0) =q(x_{0:T}) / q(x_{1:T}\mid x_0),带入上面的近似分布p_\theta(x_{0:T})得到对数据集分布q(x_0)的近似拟合p_\theta(x_0) =p_\theta(x_{0:T}) / q(x_{1:T}\mid x_0),注意到此时表达式中已经不包含未知项(要么是模型拟合得到的项,要么是预先定义的项),接下来要考虑的就是如何优化这个拟合的分布(优化其中模型相关的部分)。

3. 优化

(1)要最小化的目标L

        考虑最小化负的对数似然函数E_{q(x_0)}[-log p_\theta(x_0)] = E_{q(x_{0:T})}[-log p_\theta(x_{0:T})],论文中的目标还加上了一项KL项(KL散度大于等于0,论文说法是变分界限),因此有E_{q(x_0)}[-log p_\theta(x_0)] \le E_{q(x_{0:T})}[-log p_\theta(x_{0:T})] + KL(q(x_{0:T}) \mid \mid q(x_0))= E_{q(x_{0:T})}[-log p_\theta(x_{0:T}) + log q(x_{1:T \mid x_0})]=E_{q(x_{0:T})}[-log p(x_T) - \sum_{t\ge 1}log (p_\theta(x_{t-1}\mid x_t) / q(x_t \mid x_{t-1}))] = : L

上面的式子可以进一步写成(利用马尔可夫性和条件概率公式)

L = E_{q(x_{0:T})}[ -log p(x_T) +\sum_{t\ge 2}log(q(x_t\mid x_{t-1},x_0)/p_\theta(x_{t-1}\mid x_t)) - logp_\theta(x_0\mid x_1) / q(x_1\mid x_0)]

其中

q(x_t\mid x_{t-1},x_0)=q(x_t,x_{t-1},x_0) / q(x_{t-1},x_0)= q(x_{t-1}\mid x_t,x_0)q(x_t,x_0)/q(x_{t-1},x_0)= q(x_{t-1}\mid x_t,x_0)q(x_t\mid x_0)/q(x_{t-1}\mid x_0)

进而

\sum_{t\ge 2} log(q(x_t\mid x_{t-1},x_0)/p_\theta(x_{t-1}\mid x_t)) = \sum_{t\ge 2}log (q(x_{t-1}\mid x_t,x_0)/p_{\theta(x_{t-1}\mid x_t)})+\sum_{t\ge 2}log(q(x_t\mid x_0)/q(x_{t-1}\mid x_0))

注意到第二项中的连加可以相消(an / a_{n-1} * a_{n-1} / a_{n-2} ...),于是上面等于下面

= \sum_{t\ge 2}log (q(x_{t-1}\mid x_t,x_0)/p_{\theta(x_{t-1}\mid x_t)})+log(q(x_T\mid x_0)/q(x_1\mid x_0))

进而可将L化简为如下(这里q具体是怎么变化的原论文没讲,一般机器学习中不会真的求期望,而是通过采样的方法解决,每个小的batch size产生的L表达式中期望内的值认为是L的无偏估计,因此可以通过采样加梯度下降实现优化,也就是这里的期望实际可以忽略。注:从数据集取图片出来就是一个采样过程,KL代表KL散度)

L = E_{q_{x_{0:T}}}[q(x_T\mid x_0)/p(x_T) +\sum_{t\ge 2}log(q(x_{t-1}\mid x_t,x_0)/p_\theta(x_{t-1}\mid x_t))-log p_\theta(x_0\mid x_1)]

=E_{q}[KL(q(x_T\mid x_0)\mid\mid p(x_T)) + \sum_{t\ge 2}D_{KL}(q(x_{t-1}\mid x_t,x_0)\mid\mid p_\theta(x_{t-1}\mid x_t))-log_{p_\theta}(x_0\mid x_1)]

上面第一项记为L_T,第二项求和内的部分记为L_{t-1},第三项记作L_0,其中L_T是常数,实际训练时不会使用

(2)q(x_{t-1}\mid x_t,x_0)怎么求

重参数化技巧:从分布N(\mu,\sigma^2I)采样等价于先从分布N(0,I)采样出z,计算出\sigma z + \mu作为分布N(\mu,\sigma^2I)的采样结果。

高斯分布性质:X\sim N(\mu_1,\sigma_1^2),Y\sim N(\mu_2,\sigma_2^2),aX + bY\sim N(a\mu_1+b\mu_2,a^2\sigma_1^2+b^2\sigma_2^2)

首先证明一个性质,任意时刻的x_t可以由x_0,\beta序列来近似的产生。

\alpha_t=1-\beta_t,\bar \alpha_t = \alpha_t * \alpha_{t-1}...*\alpha_1 另外z_1,z_2,... \sim N(0,I)

q(x_t \mid x_{t-1}) \sim N(x_t;\sqrt{1-\beta_t}x_{t-1},\beta_t I)和重参数化技巧,从q(x_t \mid x_{t-1})采样等价于

x_t = \sqrt{\alpha_t}x_{t-1} + \sqrt{1-\alpha_t}z_{t-1}

\sqrt{\alpha_t}( \sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{1-\alpha_{t-1}}z_{t-2} ) + \sqrt{1-\alpha_t}z_{t-1}

=\sqrt{\alpha_t\alpha_{t-1}}x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\bar z_{t-2}

=...

=\sqrt{\bar \alpha_t}x_0+\sqrt{1-\bar \alpha_t}z \quad *(1)

从而q(x_{t}\mid x_0)=N(x_t;\sqrt{\bar \alpha_t}x_0,(1-\bar \alpha_t)I) \quad* (2)

前面我们提到论文假设q(x_t \mid x_{t+1})是高斯分布,我们认为这里q(x_{t-1}\mid x_t,x_0)也是高斯分布,即

q(x_{t-1}\mid x_t,x_0) \sim N(x_{t-1};\tilde\mu(x_t,x_0),\tilde\beta_t I)

由条件概率公式,有

t = q(x_{t-1}\mid x_t,x_0)=q(x_t\mid x_{t-1},x_0) q(x_{t-1}\mid x_0) /q(x_t \mid x_0)

=q(x_t\mid x_{t-1})q(x_{t-1}\mid x_0)/q(x_t\mid x_0)   (马尔可夫性)

高维高斯分布p(x)=\frac{1}{\sqrt{(2\pi)^Ndet(\Sigma)}}exp(-\frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-u)),因为已经假设结果为高斯分布,所以我们忽略exp前面的系数来求均值和方差,注意前面方差(aI)^{-1}=\frac{1}{a}I

t \propto exp(-\frac{1}{2}( (x_t-\sqrt{\alpha_t}x_{t-1})^2/(1-\alpha_t) + (x_{t-1}-\sqrt{\bar\alpha_{t-1}}x_0)^2/(1-\bar\alpha_{t-1}) -(x_t-\sqrt{\bar\alpha_t}x_0)^2/(1-\bar\alpha_t) ))

=exp(-\frac{1}{2}( (\alpha_t/(1-\alpha_t)+1/(1-\bar\alpha_t))x_{t-1}^2 - (2\sqrt{\alpha_t}x_t/(1-\alpha_t)+2\sqrt{\bar\alpha_{t-1}}x_0/(1-\bar\alpha_t)) x_{t-1+C(x_t,x_0)} ))

因此\tilde \beta= 1/(\alpha_t/\beta_t+1/(1-\bar\alpha_{t-1}))=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t

\tilde \mu_t(x_t,x_0)=(\sqrt{\alpha_t}/\beta_t+\sqrt{\bar\alpha_{t-1}}/(1-\bar\alpha_t)x_0) / (\alpha_t/\beta_t+1/(1-\bar\alpha_{t-1}))

=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}x_t+\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x_0

由*(1)公式知x_0= 1/\sqrt{\bar\alpha_t}(x_t-\sqrt{1-\bar\alpha_t}z_t),代入上面得

\tilde \mu_t=1/\sqrt{\alpha_t}(x_t-\beta_t/\sqrt{1-\bar\alpha_t}z_t)

(3) L_{t-1}化简

在上面提到L_{t-1}=D_{KL}(q(x_{t-1}\mid x_t,x_0)\mid\mid p_\theta(x_{t-1}\mid x_t))

因为前面提到论文假设每步改变很小得情况下认为q(x_t \mid x_{t+1})符合高斯分布。

所以可以有如下表示p_\theta(x_{t-1}\mid x_t) \sim N(x_{t-1}\mid \mu_\theta(x_t,t),\sigma_t^2I)   *(3)

论文中固定\sigma_t^2=\beta_t或者上面的\tilde \beta_t,当然也可以改变模型通道数或者用另一个模型来预测该方差。

官方代码中的预测v目标即指需要预测该方差,只不过不是直接预测方差。官方代码提供了固定方差还是预测方差的选项。

多元高斯分布KL散度公式如下:

KL(N(x\mid\mu_1,\Sigma_1)\mid\mid N(x\mid \mu_2,\Sigma_2))=

\frac{1}{2}[ log (\mid \Sigma_2\mid /\mid \Sigma_1\mid) -K +tr(\Sigma_2^{-1}\Sigma_1) +(\mu_1-\mu_2)^T\Sigma_2^{-1}(\mu_1-\mu_2) ]

因此

L_{t-1} = \frac{1}{2} [C+(\lVert \tilde\mu(x_t,x_0) - \mu_\theta(x_t,x_0)\rVert^2) / (\sigma_t^2)]

最优情况是\tilde\mu(x_t,x_0) = \mu_\theta(x_t,x_0)

论文再次利用重参数化并假定上面等式左右两边有相同形式,从而将任务从预测方差转到预测噪声,即

\mu_\theta(x_t,t)=1/\sqrt{\alpha_t}(x_t-\beta_t/\sqrt{1-\bar\alpha_t}\epsilon_\theta(x_t,t))   * (4)

从而L_{t-1}=C_1+C_2\lVert \epsilon - \epsilon_\theta(x_t,t) \rVert^2

去除两个常系数即得到论文中提到的最终的目标损失。这里看到,模型最终转为预测x_t中加入的随机量的噪声(高斯噪声)

4.采样

(1)论文采样方式

论文算法的采样方式由*(3)和*(4)公式得到分布不断采样得到。初始从标准高斯分布采样得到x_T

不断用重参数化从分布p_\theta(x_{t-1}\mid x_t)采样,即如下公式

x_{t-1}=1/\sqrt{\alpha_t}(x_t-\beta_t/\sqrt{1-\bar\alpha_t}\epsilon_\theta(x_t,t)) +\sigma_t z

最终得到x_0即输出

(2)直接采样

论文提到可以利用公式x_0= 1/\sqrt{\bar\alpha_t}(x_t-\sqrt{1-\bar\alpha_t}z_t)直接进行采样,其中z_t=\epsilon_\theta(x_t,t)

但是这样采样得到的效果不好

二、分数扩散模型待写

  • 25
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值