深度学习——DDPM

前置数学知识

1、先验概率和后验概率

先验概率:根据以往经验和分析得到的概率,它往往作为“由因求果”问题中的“因”出现,如 q ( x t ∣ x t − 1 ) q(x_t|x_{t-1}) q(xtxt1)

后验概率:指在得到“结果”的信息后重新修正的概率,是“执果寻因”问题中的“因", 如 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt1xt)

2、条件概率:设 A A A B B B为任意两个事件,若 P ( A ) > 0 P(A)>0 P(A)>0,称在已知事件 A A A发生的条件下,事件 B B B发生的概率为条件概率,记为 P ( B ∣ A ) P(B|A) P(BA)
P ( B ∣ A ) = P ( A , B ) P ( A ) P(B|A)=\frac{P(A,B)} {P(A)} P(BA)=P(A)P(A,B)

3、乘法公式:
P ( A , B ) = P ( B ∣ A ) P ( A ) P(A,B)=P(B|A)P(A) P(A,B)=P(BA)P(A)

4、乘法公式一般形式:
P ( A , B , C ) = P ( C ∣ B , A ) P ( B , A ) = P ( C ∣ B , A ) P ( B ∣ A ) P ( A ) P(A,B,C)=P(C|B,A)P(B,A)=P(C|B,A)P(B|A)P(A)\\ P(A,B,C)=P(CB,A)P(B,A)=P(CB,A)P(BA)P(A)

5、贝叶斯公式:
P ( A ∣ B ) = P ( B ∣ A ) P ( A ) P ( B ) P(A|B)=\frac{P(B|A)P(A)}{P(B)} P(AB)=P(B)P(BA)P(A)
6、多元贝叶斯公式:
P ( A ∣ B , C ) = P ( A , B , C ) P ( B , C ) = P ( B ∣ A , C ) P ( A , C ) P ( B , C ) = P ( B ∣ A , C ) P ( A ∣ C ) P ( C ) P ( B ∣ C ) P ( C ) = P ( B ∣ A , C ) P ( A ∣ C ) ) P ( B ∣ C ) P(A|B,C)=\frac{P(A,B,C)}{P(B,C)}=\frac{P(B|A,C)P(A,C)}{P(B,C)}=\frac{P(B|A,C)P(A|C)P(C)}{P(B|C)P(C)}=\frac{P(B|A,C)P(A|C))}{P(B|C)} P(AB,C)=P(B,C)P(A,B,C)=P(B,C)P(BA,C)P(A,C)=P(BC)P(C)P(BA,C)P(AC)P(C)=P(BC)P(BA,C)P(AC))

7、正态分布的叠加性:当有两个独立的正态分布变量 N 1 N_{1} N1 N 2 N_{2} N2,它们的均值和方差分别为 μ 1 \mu_{1} μ1, μ 2 \mu_{2} μ2 σ 1 2 \sigma_{1}^2 σ12, σ 2 2 \sigma_{2}^2 σ22它们的和为 N = a N 1 + b N 2 N=a N_{1}+b N_{2} N=aN1+bN2的均值和方差可以表示如下:
E ( N ) = E ( a N 1 + b N 2 ) = a μ 1 + b μ 2 V a r ( N ) = V a r ( a N 1 + b N 2 ) = a 2 σ 1 2 + b 2 σ 2 2 E(N)=E(aN_{1}+bN_{2})=a\mu_{1}+b\mu_{2}\\ Var(N)=Var(aN_{1}+bN_{2})=a^2\sigma_{1}^2+b^2\sigma_{2}^2 E(N)=E(aN1+bN2)=aμ1+bμ2Var(N)=Var(aN1+bN2)=a2σ12+b2σ22
相减时:
E ( N ) = E ( a N 1 − b N 2 ) = a μ 1 − b μ 2 V a r ( N ) = V a r ( a N 1 − b N 2 ) = a 2 σ 1 2 + b 2 σ 2 2 E(N)=E(aN_{1}-bN_{2})=a\mu_{1}-b\mu_{2}\\ Var(N)=Var(aN_{1}-bN_{2})=a^2\sigma_{1}^2+b^2\sigma_{2}^2 E(N)=E(aN1bN2)=aμ1bμ2Var(N)=Var(aN1bN2)=a2σ12+b2σ22

8、重参数化:从 N ( μ , σ 2 ) N(\mu,\sigma^2) N(μ,σ2) 采样等价于从 N ( 0 , 1 ) N(0,1) N(0,1)采样一个 ϵ \epsilon ϵ, ϵ ⋅ σ + μ \epsilon\cdot\sigma+\mu ϵσ+μ

9、高斯分布的概率密度函数
f ( x ) = 1 2 π σ e − ( x − μ ) 2 2 σ 2 f(x)=\frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(x-\mu)^2}{2\sigma^2}} f(x)=2π σ1e2σ2(xμ)2
10、高斯分布的KL散度公式
K L ( p ∣ q ) = l o g σ 2 σ 1 + σ 2 + ( μ 1 − μ 2 ) 2 2 σ 2 2 − 1 2 KL(p|q)=log\frac{\sigma_2}{\sigma_1}+\frac{\sigma^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2} KL(pq)=logσ1σ2+2σ22σ2+(μ1μ2)221
11、二次函数配方
a x 2 + b x = a ( x + b 2 a ) 2 + c ax^2+bx=a(x+\frac{b}{2a})^2+c ax2+bx=a(x+2ab)2+c
12、随机变量的期望公式
X X X是随机变量, Y = g ( X ) Y=g(X) Y=g(X),则:
E ( Y ) = E [ g ( X ) ] = { ∑ k = 1 ∞ g ( x k ) p k ∫ − ∞ ∞ g ( x ) p ( x ) d x E(Y)=E[g(X)]= \begin{cases} \displaystyle\sum_{k=1}^\infty g(x_k)p_k\\ \displaystyle\int_{-\infty}^{\infty}g(x)p(x)dx \end{cases} E(Y)=E[g(X)]= k=1g(xk)pkg(x)p(x)dx

13、KL散度公式
K L ( p ( x ) ∣ q ( x ) ) = E x ∼ p ( x ) [ p ( x ) q ( x ) ] = ∫ p ( x ) p ( x ) q ( x ) d x KL(p(x)|q(x))=E_{x \sim p(x)}[\frac{p(x)}{q(x)}]=\int p(x) \frac{p(x)}{q(x)}dx KL(p(x)q(x))=Exp(x)[q(x)p(x)]=p(x)q(x)p(x)dx

介绍DDPM

2020年Berkeley大学的学生提出的DDPM(Denoising Diffusion Probabilistic Models),简称扩散模型,是AIGC的核心算法,在生成图像的真实性和多样性方面均超越了GAN,而且训练过程稳定。

扩散模型包括两个过程:前向扩散过程(前向加噪过程)反向去噪过程

img

前向过程和反向过程都是马尔可夫链,全过程大约需要1000步,其中反向过程用来生成数据,它的推导过程可以描述成:

img

前向扩散的过程

前向扩散过程是对原始数据逐渐增加高斯噪声,直至变成标准高斯分布的过程。

img

从原始数据集采样 x 0 ∼ q ( x 0 ) x_0\sim q(x_0) x0q(x0),按照预定义的noise schedule策略添加随机噪声,得到一系列噪声图像 x 1 , x 2 , … , x T x_1,x_2,\dots,x_T x1,x2,,xT,用概率表示为:
q ( x 1 : T ∣ x 0 ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) q ( x t ∣ x t − 1 ) = N ( x t ; α t x t − 1 , β t I ) α t = 1 − β t \begin{aligned} q(x_{1:T}|x_{0})&=\prod_{t=1}^{T}q(x_t|x_{t-1}) \\q(x_{t}|x_{t-1})&=\mathcal{N}(x_t;\sqrt{\alpha_t}x_{t-1},\beta_{t}I)\\ \alpha_{t}&=1-\beta_{t} \end{aligned} q(x1:Tx0)q(xtxt1)αt=t=1Tq(xtxt1)=N(xt;αt xt1,βtI)=1βt
进行重参数化,得到
x t = α t x t − 1 + β t ϵ t      ϵ t ∼ N ( 0 , I ) \begin{aligned} x_{t}&=\sqrt{\alpha_{t}}x_{t-1}+\sqrt{\beta_{t}}\epsilon_{t} \space \space \space \space \epsilon_{t}\sim \mathcal{N}(0,I) \\ \end{aligned} xt=αt xt1+βt ϵt    ϵtN(0,I)

利用上述公式进行迭代推导
x t = α t x t − 1 + β t ϵ t = α t ( α t − 1 x t − 2 + β t − 1 ϵ t − 1 ) + β t ϵ t = α t α t − 1 ( α t − 2 x t − 3 + β t − 2 ϵ t − 2 ) + α t β t − 1 ϵ t − 1 + β t ϵ t = ( α t … α 1 ) x 0 + ( α t … α 2 ) β 1 ϵ 1 + ( α t … α 3 ) β 2 ϵ 2 + ⋯ + α t β t − 1 ϵ t − 1 + β t ϵ t \begin{aligned} x_{t}&=\sqrt{\alpha_{t}} x_{t-1}+\sqrt{\beta_{t}}\epsilon_{t}\\ &=\sqrt{\alpha_{t}}(\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{\beta_{t-1}}\epsilon_{t-1})+\sqrt{\beta_{t}}\epsilon_{t}\\ &=\sqrt{\alpha_{t}\alpha_{t-1}}(\sqrt{\alpha_{t-2}}x_{t-3}+\sqrt{\beta_{t-2}}\epsilon_{t-2})+\sqrt{\alpha_{t}\beta_{t-1}}\epsilon_{t-1}+\sqrt{\beta_{t}}\epsilon_{t}\\ &=\sqrt{(\alpha_{t}\dots\alpha_{1})}x_{0}+\sqrt{(\alpha_{t}\dots\alpha_{2})\beta_{1}}\epsilon_{1}+\sqrt{(\alpha_{t}\dots\alpha_{3})\beta_{2}}\epsilon_{2}+\dots+\sqrt{\alpha_{t}\beta_{t-1}}\epsilon_{t-1}+\sqrt{\beta_{t}}\epsilon_{t} \end{aligned} xt=αt xt1+βt ϵt=αt (αt1 xt2+βt1 ϵt1)+βt ϵt=αtαt1 (αt2 xt3+βt2 ϵt2)+αtβt1 ϵt1+βt ϵt=(αtα1) x0+(αtα2)β1 ϵ1+(αtα3)β2 ϵ2++αtβt1 ϵt1+βt ϵt

设: α t ˉ = α 1 α 2 … α t \bar{\alpha_{t}}=\alpha_{1}\alpha_{2}\dots\alpha_{t} αtˉ=α1α2αt 根据正态分布的叠加性得到
μ = α t ˉ σ 2 = ( α t … α 2 ) β 1 + ( α t … α 3 ) β 2 + ⋯ + α t β t − 1 + β t = ( α t … α 2 ) ( 1 − α 1 ) + ( α t … α 3 ) ( 1 − α 2 ) + ⋯ + α t ( 1 − α t − 1 ) + 1 − α t = 1 − α t ˉ \begin{aligned} \mu&=\sqrt{\bar{\alpha_{t}}}\\ \sigma^2&=(\alpha_{t}\dots\alpha_{2})\beta_{1}+(\alpha_{t}\dots\alpha_{3})\beta_{2}+\dots+\alpha_{t}\beta_{t-1}+\beta_{t}\\ &=(\alpha_{t}\dots\alpha_{2})(1-\alpha_{1})+(\alpha_{t}\dots\alpha_{3})(1-\alpha_{2})+\dots+\alpha_{t}(1-\alpha_{t-1})+1-\alpha_{t}\\ &=1-\bar{\alpha_{t}} \end{aligned} μσ2=αtˉ =(αtα2)β1+(αtα3)β2++αtβt1+βt=(αtα2)(1α1)+(αtα3)(1α2)++αt(1αt1)+1αt=1αtˉ
从而得到前向过程的最终表达式

x t = α t ˉ x 0 + 1 − α t ˉ ϵ     ϵ ∼ N ( 0 , I ) q ( x t ∣ x 0 ) = N ( x t ; α t ˉ x 0 , ( 1 − α t ˉ ) I ) \begin{aligned} x_{t}&=\sqrt{\bar{\alpha_{t}}}x_{0}+\sqrt{1-\bar{\alpha_{t}}}\epsilon \space \space\space \epsilon\sim \mathcal{N}(0,I)\\ q(x_{t}|x_{0})&=\mathcal{N}(x_{t};\sqrt{\bar{\alpha_{t}}}x_{0},(1-\bar{\alpha_{t}})I) \end{aligned} xtq(xtx0)=αtˉ x0+1αtˉ ϵ   ϵN(0,I)=N(xt;αtˉ x0,(1αtˉ)I)
这个公式表示任意步骤 t t t的噪声图像 x t x_t xt ,都可以通过 x 0 x_0 x0直接加噪得到,后面需要用到。

反向去噪过程,神经网络拟合过程

反向去噪过程就是数据生成过程,它首先是从标准高斯分布中采样得到一个噪声样本,再一步步地迭代去噪,最后得到数据分布中的一个样本。

img

如果知道反向过程的每一步真实的条件分布 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt),那么从一个随机噪声开始,逐步采样就能生成一个真实的样本。但是真实的条件分布利用贝叶斯公式
q ( x t − 1 ∣ x t ) = q ( x t ∣ x t − 1 ) q ( x t − 1 ) q ( x t ) q(x_{t-1}|x_{t}) =\frac{q(x_{t}|x_{t-1})q(x_{t-1})}{q(x_{t})} q(xt1xt)=q(xt)q(xtxt1)q(xt1)
无法直接求解,原因是其中 q ( x t − 1 ) q(x_{t-1}) q(xt1) , q ( x t ) q(x_{t}) q(xt) 未知,因此无法从 x t x_{t} xt 推导到 x t − 1 {x_{t-1}} xt1,所以必须通过神经网络 p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1}|x_t) pθ(xt1xt)来近似。为了简化起见,将反向过程也定义为一个马尔卡夫链,且服从高斯分布,建模如下:

p θ ( x 0 : T ) = p ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , ∑ θ ( x t , t ) ) p_\theta(x_{0:T})=p(x_T)\prod_{t=1}^Tp_\theta(x_{t-1}|x_t)\\ p_\theta(x_{t-1}|x_t)=N(x_{t-1};\mu_\theta(x_t,t),\sum_\theta(x_t,t)) pθ(x0:T)=p(xT)t=1Tpθ(xt1xt)pθ(xt1xt)=N(xt1;μθ(xt,t),θ(xt,t))

虽然真实条件分布 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt)无法直接求解,但是加上已知条件 x 0 x_0 x0的后验分布$q(x_{t-1}|x_{t},x_{0}) $却可以通过贝叶斯公式求解,再结合前向马尔科夫性质可得
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) = q ( x t ∣ x t − 1 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(x_{t-1}|x_{t},x_{0}) =\frac{q(x_{t}|x_{t-1},x_{0})q(x_{t-1}|x_{0})}{q(x_{t}|x_{0})}=\frac{q(x_{t}|x_{t-1})q(x_{t-1}|x_{0})}{q(x_{t}|x_{0})} q(xt1xt,x0)=q(xtx0)q(xtxt1,x0)q(xt1x0)=q(xtx0)q(xtxt1)q(xt1x0)
因此可以得到:
q ( x t − 1 ∣ x 0 ) = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 ϵ ∼ N ( α ˉ t − 1 x 0 , ( 1 − α ˉ t − 1 ) I ) q ( x t ∣ x 0 ) = α ˉ t x 0 + 1 − α ˉ t ϵ ∼ N ( α ˉ t x 0 , ( 1 − α ˉ t ) I ) q ( x t ∣ x t − 1 ) = α t x t − 1 + β t ϵ ∼ N ( α t x t − 1 , β t I ) \begin{aligned} q(x_{t-1}|x_{0})&=\sqrt{\bar{\alpha}_{t-1}}x_{0}+\sqrt{1-\bar{\alpha}_{t-1}}\epsilon\sim \mathcal{N}(\sqrt{\bar{\alpha}_{t-1}}x_{0},(1-\bar{\alpha}_{t-1})I)\\ q(x_{t}|x_{0})&=\sqrt{\bar{\alpha}_{t}}x_{0}+\sqrt{1-\bar{\alpha}_{t}}\epsilon\sim \mathcal{N}(\sqrt{\bar{\alpha}_{t}}x_{0},(1-\bar{\alpha}_{t})I)\\ q(x_{t}|x_{t-1})&=\sqrt{\alpha}_{t}x_{t-1}+\beta_{t}\epsilon\sim \mathcal{N}(\sqrt{\alpha}_{t}x_{t-1},\beta_{t}I) \end{aligned} q(xt1x0)q(xtx0)q(xtxt1)=αˉt1 x0+1αˉt1 ϵN(αˉt1 x0,(1αˉt1)I)=αˉt x0+1αˉt ϵN(αˉt x0,(1αˉt)I)=α txt1+βtϵN(α txt1,βtI)
所以
q ( x t − 1 ∣ x t , x 0 ) ∝ e x p ( − 1 2 ( ( x t − α t x t − 1 ) 2 β t ) + ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) = e x p ( − 1 2 ( α t β t + 1 1 − α ˉ t − 1 ) x t − 1 2 − ( 2 α t β t x t + 2 α t ˉ 1 − α t ˉ x 0 ) x t − 1 + C ( x t , x 0 ) ) \begin{aligned} q(x_{t-1}|x_{t},x_{0}) &\propto exp(-\frac{1}{2}(\frac{(x_{t}-\sqrt{\alpha_{t}}x_{t-1})^2}{\beta_{t}})+\frac{(x_{t-1}-\sqrt{\bar{\alpha}}_{t-1}x_{0})^2}{1-\bar{\alpha}_{t-1}}-\frac{(x_{t}-\sqrt{\bar{\alpha}_{t}}x_{0})^2}{1-\bar{\alpha}_{t}})\\ &=exp(-\frac{1}{2}(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}})x_{t-1}^2-(\frac{2\sqrt{\alpha_{t}}}{\beta_{t}}x_{t}+\frac{2\sqrt{\bar{\alpha_{t}}}}{1-\bar{\alpha_{t}}}x_{0})x_{t-1}+C(x_{t},x_{0})) \end{aligned} q(xt1xt,x0)exp(21(βt(xtαt xt1)2)+1αˉt1(xt1αˉ t1x0)21αˉt(xtαˉt x0)2)=exp(21(βtαt+1αˉt11)xt12(βt2αt xt+1αtˉ2αtˉ x0)xt1+C(xt,x0))

通过配方就可以得到
β ~ t = 1 / ( α t β t + 1 1 − α ˉ t − 1 ) = 1 − α ˉ t − 1 1 − α ˉ t β t μ ~ t = ( α t β t x t + α ˉ t 1 − α t ˉ x 0 ) / ( α t β t + 1 1 − α ˉ t − 1 ) = α t ( 1 − α ˉ t − 1 ) 1 − α t ˉ x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \widetilde{\beta}_t=1/(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}})=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}}\beta_{t}\\ \widetilde{\mu}_t=(\frac{\sqrt\alpha_{t}}{\beta_{t}}x_{t}+\frac{\sqrt{\bar{\alpha}_{t}}}{1-\bar{\alpha_{t}}}x_{0})/(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}})=\frac{\sqrt{\alpha_{t}}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha_{t}}}x_{t}+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_{t}}{1-\bar{\alpha}_{t}}x_{0} β t=1/(βtαt+1αˉt11)=1αˉt1αˉt1βtμ t=(βtα txt+1αtˉαˉt x0)/(βtαt+1αˉt11)=1αtˉαt (1αˉt1)xt+1αˉtαˉt1 βtx0

又因为
x 0 = 1 α ˉ t ( x t − β t 1 − α ˉ t ϵ ) x_0= \frac{1}{\sqrt{\bar\alpha_t}}(x_t- \frac{\beta_t}{\sqrt{1-\bar \alpha_t} }\epsilon)\\ x0=αˉt 1(xt1αˉt βtϵ)
代入上式
μ ~ t = 1 α t ( x t − β t ( 1 − α t ) ϵ ) \widetilde{\mu}_t=\frac{1}{\sqrt{\alpha_t} }(x_t-\frac{\beta_t}{\sqrt{(1-\alpha_t)}}\epsilon) μ t=αt 1(xt(1αt) βtϵ)

这样我们就得到了后向过程 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0)的均值和方差的公式,为了让 p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1}|x_t) pθ(xt1xt)逼近 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0),我们固定住 ∑ θ ( x t , t ) = β t \sum_\theta(x_t,t)=\beta_t θ(xt,t)=βt或者 ∑ θ ( x t , t ) = β ~ t \sum_\theta(x_t,t)=\widetilde{\beta}_t θ(xt,t)=β t β t \beta_t βt β ~ t \widetilde{\beta}_t β t其实大小是差不多的,但其实用 β ~ t \widetilde{\beta}_t β t更好理解一点,因此只需要预测 μ θ ( x t , t ) \mu_\theta(x_t,t) μθ(xt,t) ,又因为这里面只有 ϵ \epsilon ϵ 是未知的,所以转而预测 ϵ \epsilon ϵ ,因此均值可以写成下面的式子:
μ θ ( x t , t ) = 1 α t ( x t − β t ( 1 − α t ) ϵ θ ( x t , t ) ) \mu_\theta(x_t,t)=\frac{1}{\sqrt{\alpha_t} }(x_t-\frac{\beta_t}{\sqrt{(1-\alpha_t)}}\epsilon_\theta(x_t,t)) μθ(xt,t)=αt 1(xt(1αt) βtϵθ(xt,t))

采样过程(模型训练完后的预测过程)

p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , ∑ θ ( x t , t ) ) x t − 1 = 1 α t ( x t − β t ( 1 − α t ) ϵ θ ( x t , t ) ) + β ~ t z      z ∼ N ( 0 , I ) p_\theta(x_{t-1}|x_t)=N(x_{t-1};\mu_\theta(x_t,t),\sum_\theta(x_t,t))\\ x_{t-1}=\frac{1}{\sqrt{\alpha_t} }(x_t-\frac{\beta_t}{\sqrt{(1-\alpha_t)}}\epsilon_\theta(x_t,t))+\sqrt{\widetilde{\beta}_t}z \space \space\space\space z\sim N(0,I) pθ(xt1xt)=N(xt1;μθ(xt,t),θ(xt,t))xt1=αt 1(xt(1αt) βtϵθ(xt,t))+β t z    zN(0,I)
这里用z是为了和之前的 ϵ \epsilon ϵ区别开,迭代1000次

损失函数

https://blog.csdn.net/weixin_45453121/article/details/131223653

Code

import torch
import torchvision
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np
from torch.optim import Adam
from torch import nn
import math
from torchvision.utils import save_image


def show_images(data, num_samples=20, cols=4):
    """ Plots some samples from the dataset """
    plt.figure(figsize=(15,15))
    for i, img in enumerate(data):
        if i == num_samples:
            break
        plt.subplot(int(num_samples/cols) + 1, cols, i + 1)
        plt.imshow(img[0])


def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)


def get_index_from_list(vals, t, x_shape):
    """
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    #print("out:",out)
    #print("out.shape:",out.shape)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """
    Takes an image and a timestep as input and
    returns the noisy version of it
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
    + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)






def load_transformed_dataset(IMG_SIZE):
    data_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(), # Scales data into [0,1]
        transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1]
    ]
    data_transform = transforms.Compose(data_transforms)

    train = torchvision.datasets.MNIST(root="./Data",transform=data_transform,train=True)
    test = torchvision.datasets.MNIST(root="./Data", transform=data_transform, train=False)

    return torch.utils.data.ConcatDataset([train, test])

def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    #Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))





class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp =  nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU()

    def forward(self, x, t):
        #print("ttt:",t.shape)
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(..., ) + (None, ) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.
    """
    def __init__(self):
        super().__init__()
        image_channels =1   #灰度图为1,彩色图为3
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 1   #灰度图为1 ,彩色图为3
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_emb_dim),
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
                                    time_emb_dim) \
                    for i in range(len(down_channels)-1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
                                        time_emb_dim, up=True) \
                    for i in range(len(up_channels)-1)])

        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        # Embedd time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)




def get_loss(model, x_0, t):
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    return F.l1_loss(noise, noise_pred)



@torch.no_grad()
def sample_timestep(x, t):
    """
    Calls the model to predict the noise in the image and returns
    the denoised image.
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )

    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    if t == 0:
        # As pointed out by Luis Pereira (see YouTube comment)
        # The t's are offset from the t's in the paper
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def sample_plot_image(IMG_SIZE):
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 1, img_size, img_size), device=device)   #生成第T步的图片
    plt.figure(figsize=(15,15))
    plt.axis('off')
    num_images = 10
    stepsize = int(T/num_images)

    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        #print("t:",t)
        img = sample_timestep(img, t)
        # Edit: This is to maintain the natural range of the distribution
        img = torch.clamp(img, -1.0, 1.0)
        if i % stepsize == 0:
            plt.subplot(1, num_images, int(i/stepsize)+1)
            plt.title(str(i))
            show_tensor_image(img.detach().cpu())
    plt.show()


if __name__ =="__main__":

    # Define beta schedule
    T = 300
    betas = linear_beta_schedule(timesteps=T)

    # Pre-calculate different terms for closed form
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, axis=0)
    # print(alphas_cumprod.shape)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    # print(alphas_cumprod_prev)
    # print(alphas_cumprod_prev.shape)
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    # print(posterior_variance.shape)


    IMG_SIZE = 32
    BATCH_SIZE = 16

    data = load_transformed_dataset(IMG_SIZE)
    dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    model = SimpleUnet()
    print("Num params: ", sum(p.numel() for p in model.parameters()))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = Adam(model.parameters(), lr=0.001)
    epochs = 1 # Try more!

    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):  #由于batch 是包含标签的所以取batch[0]
            #print(batch[0].shape)
            optimizer.zero_grad()

            t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
            loss = get_loss(model, batch[0], t)
            loss.backward()
            optimizer.step()

            if epoch % 1 == 0 and step %5== 0:
                print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
                sample_plot_image(IMG_SIZE)

参考文献

https://zhuanlan.zhihu.com/p/630354327

https://blog.csdn.net/weixin_45453121/article/details/131223653

https://www.cnblogs.com/risejl/p/17448442.html

https://zhuanlan.zhihu.com/p/569994589?utm_id=0

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

千禧皓月

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值