前景
为了系统地学习diffusion model的理论与实现,开个贴记录学习。主要侧重在图生图方向。
Diffusion model
扩散模型包括两个过程:前向过程(forward process)和反向过程(reverse process),其中前向过程又称为扩散过程(diffusion process)。无论是前向过程还是反向过程都是一个参数化的马尔可夫链(Markov chain),这是因为当前状态只和上个状态有关,通过最初始状态以及t时刻就能根据函数得到该时刻下的状态。其中反向过程可用于生成数据样本,通过设定最大迭代次数逐步去噪实现图像生成。
X0就是原图,XT表示X0经过T轮的噪声叠加得到的噪声图(前向过程)。
前向过程
根据马尔科夫链,xt由xt-1得到:
β是t的函数,一般设置线性或者正余弦等函数关系,用来控制加入噪声的强度。根据以上公式,可以通过重参数化采样得到xt
逆向过程
该过程就是逐步去噪的过程,这一步就是利用神经网络模拟的:
这个分布无法直接求出,但是根据马尔科夫链的规则,可以直接加入x0依旧等价,接着就可进行求解:第一二步根据贝叶斯多变量概率分布公式得到,第三步根据假设,t只和t-1时刻状态有关因此x0可以直接拿掉。第四步带入高斯分布概率密度函数,我们需要求均值和方差,那么只需要看指数部分。
方差只和设定的参数有关因此是个常数
均值可以看到只和xt和x0有关,x0又可以通过xt进行转换(xt通过x0和噪声链式叠加得到)
得到这个结果之后就可以通过算数据目标分布的似然函数构建loss求解
对于单一变量的q和p,kl散度直接是
u展开,相减之后只剩下eps,xt转为x0和t还有eps的表达形式得到最终的loss。模型估计的就是eps。
得到估计的eps就能进行采样
根据逆扩散过程公式:
因为假设分布为高斯分布且满足马尔科夫链,因此得到均值和方差之后可以通过重参数技巧根据t时刻的均值和方差得到t-1时刻的分布。
DDPM论文假设方差是已知的,只需要利用神经网络学习均值,具体推理过程:扩散模型 Diffusion Models - 原理篇 - 荏苒岁月的文章 - 知乎
https://zhuanlan.zhihu.com/p/548112711
代码部分
首先需要对t时刻进行时间步编码,这里以正余弦编码为例:
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class TimeEmbedding(nn.Module):
def __init__(self, T, d_model, dim):
assert d_model % 2 == 0
super().__init__()
emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
emb = torch.exp(-emb)
pos = torch.arange(T).float()
emb = pos[:, None] * emb[None, :]
assert list(emb.shape) == [T, d_model // 2]
emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
assert list(emb.shape) == [T, d_model // 2, 2]
emb = emb.view(T, d_model)
self.timembedding = nn.Sequential(
nn.Embedding.from_pretrained(emb),
nn.Linear(d_model, dim),
Swish(),
nn.Linear(dim, dim),
)
self.initialize()
def initialize(self):
for module in self.modules():
if isinstance(module, nn.Linear):
init.xavier_uniform_(module.weight)
init.zeros_(module.bias)
def forward(self, t):
emb = self.timembedding(t)
return emb
然后这个时间步位置编码会和输入一起输送到每层网络中,ddpm中以resnet为基础模型:
class ResBlock(nn.Module):
def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
super().__init__()
self.block1 = nn.Sequential(
nn.GroupNorm(32, in_ch),
Swish(),
nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
)
self.temb_proj = nn.Sequential(
Swish(),
nn.Linear(tdim, out_ch),
)
self.block2 = nn.Sequential(
nn.GroupNorm(32, out_ch),
Swish(),
nn.Dropout(dropout),
nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
)
if in_ch != out_ch:
self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
else:
self.shortcut = nn.Identity()
if attn:
self.attn = AttnBlock(out_ch)
else:
self.attn = nn.Identity()
self.initialize()
def initialize(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
init.xavier_uniform_(module.weight)
init.zeros_(module.bias)
init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)
def forward(self, x, temb):
h = self.block1(x)
h += self.temb_proj(temb)[:, :, None, None]
h = self.block2(h)
h = h + self.shortcut(x)
h = self.attn(h)
return h
时间步位置编码和输入分别通过卷积块实现通道一致然后相加作为融合。后续就和普通网络做一样的操作。
然后再看前向过程的代码,这一步是设计扩散的参数
def extract(v, t, x_shape):
"""
Extract some coefficients at specified timesteps, then reshape to
[batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
out = torch.gather(v, index=t, dim=0).float()
return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
class GaussianDiffusionTrainer(nn.Module):
def __init__(self, model, beta_1, beta_T, T):
super().__init__()
self.model = model
self.T = T
self.register_buffer(
'betas', torch.linspace(beta_1, beta_T, T).double())
alphas = 1. - self.betas
alphas_bar = torch.cumprod(alphas, dim=0)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
'sqrt_alphas_bar', torch.sqrt(alphas_bar))
self.register_buffer(
'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
def forward(self, x_0):
"""
Algorithm 1.
"""
t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
noise = torch.randn_like(x_0)
x_t = (
extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
return loss
model是所选用的网络比如上述提到的resnet。对于训练是比较简单的。每一轮训练都需要获取1-T中的某个时刻t,然后根据t算出相应的。根据参数命名就知道各个参数对应公式的哪个值了。extract就是提取t时刻的参数。
由此不断循环训练迭代,达到收敛后就可以从T时刻开始逆扩散逐步去噪直到t=0.