参考:
[1] 张振虎博客
[2] https://www.bilibili.com/video/BV1s8411i7cU/?spm_id_from=333.788&vd_source=9e9b4b6471a6e98c3e756ce7f41eb134
[3] https://zhuanlan.zhihu.com/p/660518657
[4] https://zhuanlan.zhihu.com/p/640631667
1 前言
我们在DDPM或DDIM生成图像时是通常是不可控的,因为它是由一张随即高斯噪声一步步去噪得到生成图像。如果我们想要这个过程是可控的话,最直观的一个做法就是在生成过程中加上一个条件
y
y
y,既整个过程的变为:
p
(
x
1
:
T
∣
x
0
,
y
)
p(x_{1:T}|x_0,y)
p(x1:T∣x0,y)
接下来就是讨论加上了条件
y
y
y对于公式有无影响。
首先扩散模型遵循马尔科夫链性质,所以我们可以得出:
p
(
x
t
∣
x
t
−
1
,
y
)
:
=
p
(
x
t
∣
x
t
−
1
)
p(x_t|x_{t-1},y) := p(x_t|x_{t-1})
p(xt∣xt−1,y):=p(xt∣xt−1)
基于这一事实,我们还可以推出:
q
^
(
x
t
∣
x
t
−
1
)
=
∫
y
q
^
(
x
t
,
y
∣
x
t
−
1
)
d
y
=
∫
y
q
^
(
x
t
∣
y
,
x
t
−
1
)
q
^
(
y
∣
x
t
−
1
)
d
y
=
∫
y
q
^
(
x
t
∣
x
t
−
1
)
q
^
(
y
∣
x
t
−
1
)
d
y
=
q
^
(
x
t
∣
x
t
−
1
)
=
q
^
(
x
t
∣
x
t
−
1
,
y
)
\begin{aligned} \hat q(x_t|x_{t-1}) &= \int_y \hat{q}(x_t,y|x_{t-1})dy\\ &=\int_y\hat q(x_t|y,x_{t-1})\hat q(y|x_{t-1})dy\\ &= \int_y \hat q(x_t|x_{t-1})\hat q(y|x_{t-1})dy\\ & = \hat q(x_t|x_{t-1}) = \hat q(x_t|x_{t-1},y) \end{aligned}
q^(xt∣xt−1)=∫yq^(xt,y∣xt−1)dy=∫yq^(xt∣y,xt−1)q^(y∣xt−1)dy=∫yq^(xt∣xt−1)q^(y∣xt−1)dy=q^(xt∣xt−1)=q^(xt∣xt−1,y)
同样用全概率公式,可以推出:
所以我们可以断论:加上条件
y
y
y 对前向过程毫无影响
2 Classifier Guidance
逆向过程有如下公式:
p
^
(
x
t
−
1
∣
x
t
,
y
)
=
p
^
(
x
t
−
1
∣
x
t
)
p
^
(
y
∣
x
t
−
1
,
x
t
)
p
^
(
y
∣
x
t
)
\hat p(x_{t-1}|x_t,y)=\frac{\hat p(x_{t-1}|x_t)\hat p(y|x_{t-1},x_t)}{\hat p(y|x_t)}
p^(xt−1∣xt,y)=p^(y∣xt)p^(xt−1∣xt)p^(y∣xt−1,xt)
其中分母和
x
t
−
1
x_{t-1}
xt−1毫无关系,所以分母可以看作是常数
C
C
C。
而我们知道加上条件对于扩散过程是没有影响的,所以我们还已知:
q
^
(
x
t
∣
x
t
−
1
,
y
)
=
q
(
x
t
∣
x
t
−
1
)
q
^
(
x
0
)
=
q
(
x
0
)
q
^
(
x
1
:
T
∣
x
0
,
y
)
=
∏
t
=
1
T
q
(
x
t
∣
x
t
−
1
,
y
)
\hat q(x_t|x_{t-1},y) = q(x_t|x_{t-1})\\ \hat q(x_0)=q(x_0)\\ \hat q(x_{1:T}|x_0,y) = \prod_{t=1}^Tq(x_t|x_{t-1},y)
q^(xt∣xt−1,y)=q(xt∣xt−1)q^(x0)=q(x0)q^(x1:T∣x0,y)=t=1∏Tq(xt∣xt−1,y)
现在我们未知的是 p ^ ( x t − 1 ∣ x t ) 和 p ^ ( y ∣ x t − 1 , x t ) \hat p(x_{t-1}|x_t)和\hat p(y|x_{t-1},x_t) p^(xt−1∣xt)和p^(y∣xt−1,xt),现在来推导这两项:
1)推导
p
^
(
x
t
−
1
∣
x
t
)
\hat p(x_{t-1}|x_t)
p^(xt−1∣xt)
根据贝叶斯公式,我们有:
p
^
(
x
t
−
1
∣
x
t
)
=
p
^
(
x
t
∣
x
t
−
1
)
p
^
(
x
t
−
1
)
p
^
(
x
t
)
\hat p(x_{t-1}|x_t) = \frac{\hat p(x_t|x_{t-1})\hat p(x_{t-1})}{\hat p(x_t)}
p^(xt−1∣xt)=p^(xt)p^(xt∣xt−1)p^(xt−1)
我们已知条件
y
y
y 对扩散过程不影响(可以通过全概率公式推出),所以我们有
p
^
(
x
t
∣
x
t
−
1
)
=
p
(
x
t
∣
x
t
−
1
)
\hat p(x_t|x_{t-1})=p(x_t|x_{t-1})
p^(xt∣xt−1)=p(xt∣xt−1)
我们同样可以由全概率公式推出:
p
^
(
x
t
)
=
p
(
x
t
)
\hat p(x_t) = p(x_t)
p^(xt)=p(xt)
所以
p
^
(
x
t
−
1
∣
x
t
)
=
p
(
x
t
∣
x
t
−
1
)
p
(
x
t
−
1
)
p
(
x
t
)
\hat p(x_{t-1}|x_t) = \frac{p(x_t|x_{t-1})p(x_{t-1})}{p(x_t)}
p^(xt−1∣xt)=p(xt)p(xt∣xt−1)p(xt−1)
2)推导 p ^ ( y ∣ x t − 1 , x t ) \hat p(y|x_{t-1},x_t) p^(y∣xt−1,xt)
根据贝叶斯公式,有:
p
^
(
y
∣
x
t
−
1
,
x
t
)
=
p
^
(
x
t
∣
y
,
x
t
−
1
)
p
^
(
y
∣
x
t
−
1
)
p
^
(
x
t
∣
x
t
−
1
)
\hat p(y|x_{t-1},x_t)=\frac{\hat p(x_t|y,x_{t-1})\hat p(y|x_{t-1})}{\hat p(x_t|x_{t-1})}
p^(y∣xt−1,xt)=p^(xt∣xt−1)p^(xt∣y,xt−1)p^(y∣xt−1)
根据马尔可夫链性质,所以约去分子的第一项和分母,所以得到
p
^
(
y
∣
x
t
−
1
,
x
t
)
=
p
^
(
y
∣
x
t
−
1
)
\hat p(y|x_{t-1},x_t) = \hat p(y|x_{t-1})
p^(y∣xt−1,xt)=p^(y∣xt−1)
3)终极目标
所以我们的公式此刻为:
p
^
(
x
t
−
1
∣
x
t
,
y
)
=
q
(
x
t
−
1
∣
x
t
)
q
(
y
∣
x
t
−
1
)
q
(
y
∣
x
t
)
=
C
∗
q
(
x
t
−
1
∣
x
t
)
∗
q
(
y
∣
x
t
−
1
)
\hat p(x_{t-1}|x_t,y) = \frac{q(x_{t-1}|x_t)q(y|x_{t-1})}{q(y|x_t)} = C*q(x_{t-1}|x_t)*q(y|x_{t-1})
p^(xt−1∣xt,y)=q(y∣xt)q(xt−1∣xt)q(y∣xt−1)=C∗q(xt−1∣xt)∗q(y∣xt−1)
其中第一项是常数,第二项为DDPM的目标,第三项既为分类器输出概率(根据
x
t
−
1
x_{t-1}
xt−1输出类别标签
y
y
y)
4)问题与进一步推导
我们此刻为
t
t
t时刻,我们是不可以得出
x
t
−
1
x_{t-1}
xt−1的。但是我们只是每一次从
x
t
x_t
xt到
x
t
−
1
x_{t-1}
xt−1实际上只做了很微小的变化,所以我们是可以近似
x
t
−
1
x_{t-1}
xt−1的,用泰勒展开式去近似。
我们有
l
o
g
p
θ
(
x
t
−
1
∣
x
t
)
=
−
1
2
(
x
t
−
1
−
μ
)
2
Σ
logp_\theta(x_{t-1}|x_t) = -\frac{1}{2}\frac{(x_{t-1} -\mu)^2}{\Sigma}
logpθ(xt−1∣xt)=−21Σ(xt−1−μ)2
而
Σ
\Sigma
Σ是很小的,我们可以理解为
x
t
−
1
x_{t-1}
xt−1出现在
x
t
x_t
xt的附近,而
x
t
x_t
xt约等于期望
所以令
x
t
−
1
=
μ
x_{t-1}=\mu
xt−1=μ
有
l
o
g
p
ϕ
(
y
∣
x
t
−
1
)
=
l
o
g
p
ϕ
(
y
∣
x
t
−
1
)
∣
x
t
−
1
=
μ
+
(
x
t
−
1
−
μ
)
∇
x
t
−
1
l
o
g
p
ϕ
(
y
∣
x
t
−
1
)
∣
x
t
−
1
=
μ
+
o
(
高阶
)
logp_\phi(y|x_{t-1}) = logp_\phi(y|x_{t-1})_{|x_{t-1}=\mu} +(x_{t-1}-\mu)\nabla_{x_{t-1}}logp_\phi(y|x_{t-1})_{|x_{t-1}=\mu} +o(高阶)
logpϕ(y∣xt−1)=logpϕ(y∣xt−1)∣xt−1=μ+(xt−1−μ)∇xt−1logpϕ(y∣xt−1)∣xt−1=μ+o(高阶)
又第一项和后面的高阶项相当于常数,所以约等于
l
o
g
p
ϕ
(
y
∣
x
t
−
1
)
=
(
x
t
−
1
−
μ
)
∇
x
t
−
1
l
o
g
p
ϕ
(
y
∣
x
t
−
1
)
∣
x
t
−
1
=
μ
logp_\phi(y|x_{t-1}) = (x_{t-1}-\mu)\nabla_{x_{t-1}}logp_\phi(y|x_{t-1})_{|x_{t-1}=\mu}
logpϕ(y∣xt−1)=(xt−1−μ)∇xt−1logpϕ(y∣xt−1)∣xt−1=μ
我们将两个对数相加,一番推导后(我不会)可以得到
l
o
g
p
(
x
t
−
1
∣
x
t
,
y
)
∼
N
(
μ
+
Σ
∇
l
o
g
p
ϕ
(
y
∣
x
t
−
1
)
∣
x
t
−
1
=
μ
)
logp(x_{t-1}|x_t,y) \sim N(\mu+\Sigma \nabla log p_\phi(y|x_{t-1})_{|x_{t-1}=\mu})
logp(xt−1∣xt,y)∼N(μ+Σ∇logpϕ(y∣xt−1)∣xt−1=μ)
既采样时,有
x
t
−
1
=
μ
+
Σ
∇
+
Σ
ϵ
x_{t-1} = \mu+\Sigma\nabla+\Sigma\epsilon
xt−1=μ+Σ∇+Σϵ
其中
∇
\nabla
∇为分类器的梯度,所以这么一番推导,我们只是在最后的采样公式里加了一个引导方向的梯度项。
但有个缺点就是,DDIM的
Σ
=
0
\Sigma=0
Σ=0,那么不就没用了。
而且还有两个缺点:
- 还要预训练一个分类器模型
- 只能生成分类器训练集所有的类别
5)用能量函数(score-base function)做进一步泛化
已知
s
=
∇
x
t
l
o
g
p
(
x
t
)
s = \nabla_{x_t}log p(x_t)
s=∇xtlogp(xt)
我们已知梯度和噪声的关系为:
∇
x
t
l
o
g
p
(
x
t
)
=
−
ϵ
1
−
α
ˉ
t
\nabla_{x_t}log p(x_t) = \frac{-\epsilon}{\sqrt{1-\bar\alpha_t}}
∇xtlogp(xt)=1−αˉt−ϵ
如果没有classifier guidance,那么我们的神经网络想要预测的就是
∇
x
t
l
o
g
p
θ
(
x
t
)
=
−
ϵ
θ
1
−
α
ˉ
t
\nabla_{x_t}log p_\theta(x_t) = \frac{-\epsilon_\theta}{\sqrt{1-\bar\alpha_t}}
∇xtlogpθ(xt)=1−αˉt−ϵθ
现在加上classifier guidance,也就是加上了
∇
l
o
g
p
θ
(
y
∣
x
t
)
\nabla logp_\theta(y|x_t)
∇logpθ(y∣xt),假设其值为
g
g
g(为了方便,他就是分类器梯度)
其实我们神经网络实际上是预测:
−
ϵ
θ
1
−
α
ˉ
t
+
g
\frac{-\epsilon_\theta}{\sqrt{1-\bar\alpha_t}} +g
1−αˉt−ϵθ+g
我们设其在预测
−
ϵ
′
1
−
α
ˉ
t
\frac{-\epsilon'}{\sqrt{1-\bar\alpha_t}}
1−αˉt−ϵ′
将其取等式,然后做变换,再加上一个强度因子
w
w
w,得到
ϵ
′
=
ϵ
θ
−
w
1
−
α
ˉ
t
∇
l
o
g
p
ϕ
(
y
∣
x
t
)
\epsilon' = \epsilon_\theta -w\sqrt{1-\bar\alpha_t}\nabla log p_\phi(y|x_{t})
ϵ′=ϵθ−w1−αˉt∇logpϕ(y∣xt)
也就是说,我们只需要在预测的噪声上加上一点扰动即可。而扰动项为分类器的梯度。
6)纯能量函数角度推导
原论文的伪代码的两个算法也是我们推导的:
7)不严谨代码理解
classifier_model = ... # 加载一个训好的图像分类模型
y = 1 # 生成类别为 1 的图像,假设类别 1 对应“狗”这个类
guidance_scale = 7.5 # 控制类别引导的强弱,越大越强
input = get_noise(...) # 从高斯分布随机取一个跟输出图像一样 shape 的噪声图
for t in tqdm(scheduler.timesteps):
# 用 unet 推理,预测噪声
with torch.no_grad():
noise_pred = unet(input, t).sample
# 用 input 和预测出的 noise_pred 和 x_t 计算得到 x_t-1
input = scheduler.step(noise_pred, t, input).prev_sample
# classifier guidance 步骤
class_guidance = classifier_model.get_class_guidance(input, y)
input += class_guidance * guidance_scals # 把梯度加上去
3 Classifier Free Guidance
我们知道对于classifier guidance最主要的限制就是分类器!CFG的方法就是直接将条件
y
y
y也加到模型中直接训练,而不用在训练一个分类器了,这样相当于训练了一个隐式的分类器,也就是训练了无条件生成模型和有条件生成模型,只不过这两个模型融合在同一个生成模型里。数学推导如下:
这是CF最终的推导:
这是CFG最终的推导:
不严谨代码理解
clip_model = ... # 加载一个官方的 clip 模型
text = "一只狗" # 输入文本
text_embeddings = clip_model.text_encode(text) # 编码条件文本
empty_embeddings = clip_model.text_encode("") # 编码空文本
text_embeddings = torch.cat(empty_embeddings, text_embeddings) # 把它俩 concate 到一起作为条件
input = get_noise(...) # 从高斯分布随机取一个跟输出图像一样 shape 的噪声图
for t in tqdm(scheduler.timesteps):
# 用 unet 推理,预测噪声
with torch.no_grad():
# 这里同时预测出了有文本的和空文本的图像噪声
noise_pred = unet(input, t, encoder_hidden_states=text_embeddings).sample
# Classifier-Free Guidance 引导
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) # 拆成无条件和有条件的噪声
# 把【“无条件噪声”指向“有条件噪声”】看做一个向量,根据 guidance_scale 的值放大这个向量
# (当 guidance_scale = 1 时,下面这个式子退化成 noise_pred_text)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# 用预测出的 noise_pred 和 x_t 计算得到 x_t-1
input = scheduler.step(noise_pred, t, input).prev_sample
CFG代码(DDPM的基础上添加)
CFG体现在代码上最重要的部分为:
ϵ
t
=
ϵ
c
+
w
(
ϵ
c
−
ϵ
ϕ
)
\epsilon_t = \epsilon_c+w(\epsilon_c-\epsilon_\phi)
ϵt=ϵc+w(ϵc−ϵϕ)
思路:
(model上的思路:)
首先使用embedding将条件
y
y
y连同timestep
t
t
t一同嵌入到embedding space(保证维度一致),之后再相加,这样就把条件也嵌入到模型里了。
class UNet_conditional(nn.Module):
def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, device="cuda"): # 新增了num_classes
...
...
...
if num_classes is not None: # 新增
self.label_emb = nn.Embedding(num_classes, time_dim)
...
...
...
def forward(self, x, t, y):
t = t.unsqueeze(-1).type(torch.float)
t = self.pos_encoding(t, self.time_dim)
if y is not None:
t += self.label_emb(y)
(训练代码上:)
- 模型的定义要新增参数 num_classes
- 模型的调用要新增参数 y
model = UNet_conditional(num_classes=args.num_classes).to(device) # 新增
for i, (images, labels) in enumerate(pbar):
images = images.to(device)
labels = labels.to(device) # 新增labels
t = diffusion.sample_timesteps(images.shape[0]).to(device)
x_t, noise = diffusion.noise_images(images, t)
if np.random.random() < 0.1: # 10%的时间使用unconditional
labels = None
predicted_noise = model(x_t, t, labels) # 新增condition
loss = mse(noise, predicted_noise)
(采样上的思路)
原本:
- 随机采样正态分布噪声
- 由T到0进行训练:
1)传入 x t , t x_t, t xt,t到模型中,预测predicted_noise
2)根据 x t x_t xt和 predicted_noise,预测重建x_0
3)根据 x t x_t xt和 x 0 x_0 x0,计算出 p θ p_\theta pθ的均值和方差
4)计算出 x t − 1 x_{t-1} xt−1,继续下一个循环.
现在:
- 随机采样正态分布噪声
- 由T到0进行训练:
1)传入 x t , t , x_t, t, xt,t,y到模型中,预测predicted_noise
2)如果cfg_scale>0,传入 x t , t x_t, t xt,t,None预测出uncon_predicted_noise, 使用torch.lerp()函数实现以上式子。
3)根据 x t x_t xt和 predicted_noise,预测重建x_0
4)根据 x t x_t xt和 x 0 x_0 x0,计算出 p θ p_\theta pθ的均值和方差
5)计算出 x t − 1 x_{t-1} xt−1,继续下一个循环。
def p_mean_variance(self,model,x_t,t,y,cfg_scale=None,clip_denoised=True):
pred_noise = model(x_t,t,y)
if cfg_scale > 0:
uncon_pred_noise = model(x_t,t,None)
pred_noise = torch.lerp(uncon_pred_noise,pred_noise,cfg_scale)
x_recon = self.estimate_x0_from_noise(x_t,t,pred_noise)
if clip_denoised:
x_recon = torch.clamp(x_recon,min=-1.,max=1.)
p_mean,p_var = self.q_posterior_mean_variance(x_recon,x_t,t)
return p_mean,p_var