Forward Process
Defined as a Markov chain:
$$
q\left({\bold x}_{1:T}\middle\vert{\bold x}_0\right)=\prod_{t=1}^T{q\left({\bold x}_t\middle\vert{\bold x}_{t-1},{\bold x}_{t-2},\cdots,{\bold x}_0\right)}=\prod_{t=1}^T{q\left({\bold x}_t\middle\vert{\bold x}_{t-1}\right)}
$$
where
$$
q\left({\bold x}_t\middle\vert{\bold x}_{t-1}\right)={\cal N}\left({\bold x}_t;\sqrt{1-\beta_t}\cdot{\bold x}_{t-1},\beta_t{\bold I}\right)
$$
Reparameterization Trick
$$
{\bold x}_t=\sqrt{1-\beta_t}\cdot{\bold x}_{t-1}+\sqrt{\beta_t}\cdot{\boldsymbol\epsilon}_t
$$
where ${\boldsymbol\epsilon}_t\sim{\cal N}\left({\bold 0},{\bold I}\right)$.
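A minimal NumPy sketch of one forward step, written directly from the reparameterized form above. The function name `forward_step`, the constant input, and the value `beta_t = 0.1` are illustrative assumptions, not part of the original notes.

```python
import numpy as np

def forward_step(x_prev, beta_t, rng):
    """Draw x_t ~ q(x_t | x_{t-1}) via the reparameterization trick."""
    eps = rng.standard_normal(x_prev.shape)  # eps_t ~ N(0, I)
    return np.sqrt(1.0 - beta_t) * x_prev + np.sqrt(beta_t) * eps

rng = np.random.default_rng(0)
x_prev = np.full(200_000, 2.0)  # constant input so the statistics are easy to read
beta_t = 0.1                    # illustrative value (assumption)
x_t = forward_step(x_prev, beta_t, rng)

sample_mean = x_t.mean()  # expected to be near sqrt(1 - beta_t) * 2.0
sample_var = x_t.var()    # expected to be near beta_t
```

With a constant input, the empirical mean and variance of `x_t` should land close to $\sqrt{1-\beta_t}\cdot 2$ and $\beta_t$ respectively.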
Why $\mu^2+\sigma^2=1$
Unrolling the recursion:
$$
\begin{aligned}
{\bold x}_t
&=\sqrt{1-\beta_t}\left(\sqrt{1-\beta_{t-1}}\cdot{\bold x}_{t-2}+\sqrt{\beta_{t-1}}\cdot{\boldsymbol\epsilon}_{t-1}\right)+\sqrt{\beta_t}\cdot{\boldsymbol\epsilon}_t \\
&=\sqrt{(1-\beta_t)(1-\beta_{t-1})}\cdot{\bold x}_{t-2}+\sqrt{1-(1-\beta_t)(1-\beta_{t-1})}\cdot{\boldsymbol\epsilon}' \\
&=\cdots \\
&=\sqrt{\prod_{i=1}^{t}\left(1-\beta_i\right)}\cdot{\bold x}_0+\sqrt{1-\prod_{i=1}^{t}\left(1-\beta_i\right)}\cdot{\boldsymbol\epsilon}''
\end{aligned}
$$
where ${\boldsymbol\epsilon}',{\boldsymbol\epsilon}''\sim{\cal N}\left({\bold 0},{\bold I}\right)$. The second line merges the two independent Gaussian noise terms into one: their variances $(1-\beta_t)\beta_{t-1}$ and $\beta_t$ add up to $1-(1-\beta_t)(1-\beta_{t-1})$. Because the squared mean coefficient and the variance sum to 1 at every step, the same pattern holds after any number of steps.

Let $\alpha_t=1-\beta_t$ and $\bar\alpha_t=\prod_{s=1}^{t}\alpha_s$. Then
$$
q\left({\bold x}_t\middle\vert{\bold x}_0\right)={\cal N}\left({\bold x}_t;\sqrt{\bar\alpha_t}\cdot{\bold x}_0,(1-\bar\alpha_t){\bold I}\right)
$$
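The closed form can be checked numerically by propagating the mean coefficient and the variance one step at a time and comparing against $\sqrt{\bar\alpha_t}$ and $1-\bar\alpha_t$. The sketch below assumes a linear schedule from `1e-4` to `0.02` over `T = 1000` steps purely for illustration; the notes do not specify a schedule, and `q_sample` is a hypothetical helper name.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)  # assumed linear schedule, illustration only
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)     # alpha_bar_t = prod_{s<=t} alpha_s

# Propagate x_t = mean_coef * x_0 + sqrt(var) * noise through all T steps.
mean_coef, var = 1.0, 0.0
for beta in betas:
    mean_coef *= np.sqrt(1.0 - beta)  # mean shrinks by sqrt(1 - beta_t)
    var = (1.0 - beta) * var + beta   # old noise is scaled, new noise is added

def q_sample(x0, t, eps):
    """Sample x_t from x_0 in one shot; t is 1-indexed as in the text."""
    a_bar = alpha_bars[t - 1]
    return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * eps
```

After the loop, `mean_coef**2 + var` stays at 1, which is the $\mu^2+\sigma^2=1$ property in the heading.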
Reverse Process
Defined as a Markov chain as well:
$$
p_\theta({\bold x}_{0:T})=p_\theta({\bold x}_T)\prod_{t=1}^T{p_\theta\left({\bold x}_{t-1}\middle\vert{\bold x}_{t}\right)}
$$
where
$$
p_\theta\left({\bold x}_{t-1}\middle\vert{\bold x}_t\right)={\cal N}\left({\bold x}_{t-1};{\boldsymbol\mu}_\theta\left({\bold x}_t,t\right),{\boldsymbol\Sigma}_\theta\left({\bold x}_t,t\right)\right)
$$
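The factorization above corresponds to ancestral sampling: start from ${\bold x}_T$ and draw each ${\bold x}_{t-1}$ from a Gaussian centered at the predicted mean. The following is only a sketch under assumptions: `mean_fn` and `var_fn` are hypothetical stand-ins for ${\boldsymbol\mu}_\theta$ and a diagonal ${\boldsymbol\Sigma}_\theta$, and the prior $p_\theta({\bold x}_T)$ is assumed to be ${\cal N}({\bold 0},{\bold I})$, which the notes do not state.

```python
import numpy as np

def sample_reverse(mean_fn, var_fn, shape, T, rng):
    """Run the reverse chain x_T -> x_0 with per-coordinate variance var_fn."""
    x = rng.standard_normal(shape)     # assumed prior p(x_T) = N(0, I)
    for t in range(T, 0, -1):          # t = T, T-1, ..., 1
        z = rng.standard_normal(shape)
        x = mean_fn(x, t) + np.sqrt(var_fn(x, t)) * z
    return x

# Toy usage: a "model" that halves x each step, with zero variance.
toy = sample_reverse(lambda x, t: 0.5 * x, lambda x, t: 0.0,
                     (4,), 5, np.random.default_rng(0))
```

With zero variance the chain is deterministic, so the toy result is just the initial draw scaled by $0.5^5$.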
From Forward Process
The reverse conditional $q\left({\bold x}_{t-1}\middle\vert{\bold x}_t\right)$ is not tractable on its own, but it becomes tractable once conditioned on ${\bold x}_0$. By Bayes' rule:
$$
q\left({\bold x}_{t-1}\middle\vert{\bold x}_t,{\bold x}_0\right)=q\left({\bold x}_t\middle\vert{\bold x}_{t-1},{\bold x}_0\right)\cdot\frac{q\left({\bold x}_{t-1}\middle\vert{\bold x}_0\right)}{q\left({\bold x}_t\middle\vert{\bold x}_0\right)}
$$
where $q\left({\bold x}_t\middle\vert{\bold x}_{t-1},{\bold x}_0\right)=q\left({\bold x}_t\middle\vert{\bold x}_{t-1}\right)$ by the Markov property. All three factors are Gaussian, so, up to an additive constant (with $C$ collecting the terms that do not involve ${\bold x}_{t-1}$):
$$
\begin{aligned}
\log{q\left({\bold x}_{t-1}\middle\vert{\bold x}_t,{\bold x}_0\right)}
&=-\frac12\left[\frac{\left({\bold x}_t-\sqrt{\alpha_t}\cdot{\bold x}_{t-1}\right)^2}{\beta_t}+\frac{\left({\bold x}_{t-1}-\sqrt{\bar\alpha_{t-1}}\cdot{\bold x}_0\right)^2}{1-\bar\alpha_{t-1}}-\frac{\left({\bold x}_t-\sqrt{\bar\alpha_t}\cdot{\bold x}_0\right)^2}{1-\bar\alpha_t}\right] \\
&=-\frac12\left[\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar\alpha_{t-1}}\right){\bold x}_{t-1}^2-\left(\frac{2\sqrt{\alpha_t}}{\beta_t}\cdot{\bold x}_t+\frac{2\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_{t-1}}\cdot{\bold x}_0\right){\bold x}_{t-1}+C\right]
\end{aligned}
$$
Matching this against the kernel of ${\cal N}(\mu,\sigma^2)$, the coefficient of ${\bold x}_{t-1}^2$ gives the variance:
$$
\frac{1}{\sigma^2}=\frac{\alpha_t-\bar\alpha_t+\beta_t}{\beta_t\left(1-\bar\alpha_{t-1}\right)}=\frac{1-\bar\alpha_t}{1-\bar\alpha_{t-1}}\cdot\frac{1}{\beta_t}
\Longrightarrow
\sigma^2=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\cdot\beta_t\xlongequal[]{\Delta}\tilde\beta_t
$$
and the coefficient of ${\bold x}_{t-1}$ gives the mean:
$$
\mu=\frac{\sigma^2}{2}\left(\frac{2\sqrt{\alpha_t}}{\beta_t}\cdot{\bold x}_t+\frac{2\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_{t-1}}\cdot{\bold x}_0\right)=\frac{\sqrt{\alpha_t}\left(1-\bar\alpha_{t-1}\right)}{1-\bar\alpha_t}\cdot{\bold x}_t+\frac{\beta_t\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_t}\cdot{\bold x}_0\xlongequal[]{\Delta}\tilde{\boldsymbol\mu}_t({\bold x}_t,{\bold x}_0)
$$
Finally:
$$
q\left({\bold x}_{t-1}\middle\vert{\bold x}_t,{\bold x}_0\right)={\cal N}\left({\bold x}_{t-1};\tilde{\boldsymbol\mu}_t({\bold x}_t,{\bold x}_0),\tilde\beta_t{\bold I}\right)
$$
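The posterior coefficients can be tabulated for every step at once. This sketch again assumes an illustrative linear schedule, and it adopts the convention $\bar\alpha_0=1$ (an assumption, not stated in the notes) so that $t=1$ is well defined; `q_posterior` is a hypothetical helper name.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)  # assumed linear schedule, illustration only
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)
alpha_bars_prev = np.concatenate([[1.0], alpha_bars[:-1]])  # assumes alpha_bar_0 = 1

# Posterior variance and the two mean coefficients, indexed by t - 1.
beta_tilde = (1.0 - alpha_bars_prev) / (1.0 - alpha_bars) * betas
coef_xt = np.sqrt(alphas) * (1.0 - alpha_bars_prev) / (1.0 - alpha_bars)
coef_x0 = betas * np.sqrt(alpha_bars_prev) / (1.0 - alpha_bars)

def q_posterior(x_t, x0, t):
    """Mean and variance of q(x_{t-1} | x_t, x_0); t is 1-indexed."""
    i = t - 1
    return coef_xt[i] * x_t + coef_x0[i] * x0, beta_tilde[i]

# Precision read off the quadratic term, for t >= 2 (it diverges at t = 1).
precision = alphas[1:] / betas[1:] + 1.0 / (1.0 - alpha_bars_prev[1:])
```

Under the $\bar\alpha_0=1$ convention, $\tilde\beta_1=0$ and the posterior at $t=1$ collapses onto ${\bold x}_0$; for $t\ge 2$, `precision * beta_tilde[1:]` should be 1 up to rounding.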
Noise Prediction
For a given noise ${\boldsymbol\epsilon}\sim{\cal N}({\bold 0},{\bold I})$ we have
$$
{\bold x}_t({\bold x}_0,{\boldsymbol\epsilon})=\sqrt{\bar\alpha_t}\cdot{\bold x}_0+\sqrt{1-\bar\alpha_t}\cdot{\boldsymbol\epsilon}
\Longrightarrow
{\bold x}_0=\frac{{\bold x}_t({\bold x}_0,{\boldsymbol\epsilon})-\sqrt{1-\bar\alpha_t}\cdot{\boldsymbol\epsilon}}{\sqrt{\bar\alpha_t}}
$$
thus, in terms of ${\bold x}_t$ and ${\boldsymbol\epsilon}$:
$$
\begin{aligned}
\tilde{\boldsymbol\mu}_t\left({\bold x}_t,\frac{{\bold x}_t-\sqrt{1-\bar\alpha_t}\cdot{\boldsymbol\epsilon}}{\sqrt{\bar\alpha_t}}\right)
&=\frac{\sqrt{\alpha_t}\left(1-\bar\alpha_{t-1}\right)}{1-\bar\alpha_t}\cdot{\bold x}_t+\frac{\beta_t\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_t}\cdot\frac{{\bold x}_t-\sqrt{1-\bar\alpha_t}\cdot{\boldsymbol\epsilon}}{\sqrt{\bar\alpha_t}} \\
&=\frac{\alpha_t-\bar\alpha_t+\beta_t}{\left(1-\bar\alpha_t\right)\sqrt{\alpha_t}}\cdot{\bold x}_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}\sqrt{\alpha_t}}\cdot{\boldsymbol\epsilon} \\
&=\frac{1}{\sqrt{\alpha_t}}\left({\bold x}_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\cdot{\boldsymbol\epsilon}\right)
\end{aligned}
$$
Parameterizing the noise with a neural network, ${\boldsymbol\epsilon}={\boldsymbol\epsilon}_\theta({\bold x}_t,t)$, finally gives
$$
{\boldsymbol\mu}_\theta\left({\bold x}_t,t\right)=\frac{1}{\sqrt{\alpha_t}}\left({\bold x}_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\cdot{\boldsymbol\epsilon}_\theta({\bold x}_t,t)\right)
$$
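The algebra above can be sanity-checked numerically: computing the posterior mean from $({\bold x}_t,{\bold x}_0)$ and from $({\bold x}_t,{\boldsymbol\epsilon})$ should give the same vector. The schedule, the step `t = 500`, and the vector size are arbitrary choices made for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)  # assumed linear schedule, illustration only
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

t = 500                              # arbitrary step with t >= 2 (1-indexed)
i = t - 1
x0 = rng.standard_normal(8)
eps = rng.standard_normal(8)
x_t = np.sqrt(alpha_bars[i]) * x0 + np.sqrt(1.0 - alpha_bars[i]) * eps

# Posterior mean written in terms of (x_t, x_0).
mu_from_x0 = (
    np.sqrt(alphas[i]) * (1.0 - alpha_bars[i - 1]) / (1.0 - alpha_bars[i]) * x_t
    + betas[i] * np.sqrt(alpha_bars[i - 1]) / (1.0 - alpha_bars[i]) * x0
)

# The same mean written in terms of (x_t, eps).
mu_from_eps = (x_t - betas[i] / np.sqrt(1.0 - alpha_bars[i]) * eps) / np.sqrt(alphas[i])

# Inverting the forward sample recovers x_0 from x_t and eps.
x0_recovered = (x_t - np.sqrt(1.0 - alpha_bars[i]) * eps) / np.sqrt(alpha_bars[i])
```

In a trained model, `eps` would be replaced by the network output ${\boldsymbol\epsilon}_\theta({\bold x}_t,t)$; here the true noise is used only to confirm the identity.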
Loss Function
Recap:
$$
p_\theta\left({\bold x}_{t-1}\middle\vert{\bold x}_t\right)={\cal N}\left({\bold x}_{t-1};{\boldsymbol\mu}_\theta\left({\bold x}_t,t\right),\sigma_t^2{\bold I}\right)
$$
where the variance is fixed to $\sigma_t^2=\beta_t$ or $\sigma_t^2=\tilde\beta_t$. Using the KL divergence between the two Gaussians, and dropping an additive constant that does not depend on $\theta$:
$$
\begin{aligned}
{\cal L}_{t-1}
&=\mathop{\rm KL}\left(q\left({\bold x}_{t-1}\middle\vert{\bold x}_t,{\bold x}_0\right)\middle\Vert p_\theta\left({\bold x}_{t-1}\middle\vert{\bold x}_t\right)\right) \\
&={\bf E}_q\left[\frac{1}{2\sigma_t^2}\left\Vert\tilde{\boldsymbol\mu}_t({\bold x}_t,{\bold x}_0)-{\boldsymbol\mu}_\theta\left({\bold x}_t,t\right)\right\Vert^2\right] \\
&={\bf E}_{{\bold x}_0,{\boldsymbol\epsilon}}\left[\frac{1}{2\sigma_t^2}\left\Vert\frac{1}{\sqrt{\alpha_t}}\left({\bold x}_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\cdot{\boldsymbol\epsilon}\right)-\frac{1}{\sqrt{\alpha_t}}\left({\bold x}_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\cdot{\boldsymbol\epsilon}_\theta({\bold x}_t,t)\right)\right\Vert^2\right] \\
&={\bf E}_{{\bold x}_0,{\boldsymbol\epsilon}}\left[\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar\alpha_t)}\left\Vert{\boldsymbol\epsilon}-{\boldsymbol\epsilon}_\theta\left(\sqrt{\bar\alpha_t}\cdot{\bold x}_0+\sqrt{1-\bar\alpha_t}\cdot{\boldsymbol\epsilon},t\right)\right\Vert^2\right]
\end{aligned}
$$
A simplified version drops the coefficient:
$$
{\cal L}_{\rm simp}={\bf E}_{{\bold x}_0,{\boldsymbol\epsilon}}\left[\left\Vert{\boldsymbol\epsilon}-{\boldsymbol\epsilon}_\theta\left(\sqrt{\bar\alpha_t}\cdot{\bold x}_0+\sqrt{1-\bar\alpha_t}\cdot{\boldsymbol\epsilon},t\right)\right\Vert^2\right]
$$
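A Monte Carlo estimate of the simplified loss might look like the sketch below. Several things here are assumptions rather than part of the notes: `t` is drawn uniformly from $\{1,\dots,T\}$, `eps_model` is a hypothetical callable standing in for ${\boldsymbol\epsilon}_\theta$, the schedule is linear, and `zero_model` is a dummy predictor used only to make the example runnable.

```python
import numpy as np

def simple_loss(eps_model, x0, alpha_bars, rng):
    """One Monte Carlo estimate of L_simp for a batch x0 of shape (B, D)."""
    batch = x0.shape[0]
    t = rng.integers(1, len(alpha_bars) + 1, size=batch)    # assumed uniform over 1..T
    a_bar = alpha_bars[t - 1][:, None]                      # shape (B, 1) for broadcasting
    eps = rng.standard_normal(x0.shape)
    x_t = np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * eps  # closed-form forward sample
    return np.mean(np.sum((eps - eps_model(x_t, t)) ** 2, axis=1))

def zero_model(x_t, t):
    """Dummy predictor that always outputs zero noise (illustration only)."""
    return np.zeros_like(x_t)

T = 1000
betas = np.linspace(1e-4, 0.02, T)  # assumed linear schedule
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

# Per-step coefficient that L_simp drops, shown here for sigma_t^2 = beta_t.
weights = betas ** 2 / (2.0 * betas * alphas * (1.0 - alpha_bars))

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4096, 8))
loss = simple_loss(zero_model, x0, alpha_bars, rng)
```

For the zero predictor the loss reduces to the mean of $\Vert{\boldsymbol\epsilon}\Vert^2$, so it should sit near the data dimension (8 here); a trained ${\boldsymbol\epsilon}_\theta$ would push it lower.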