Latent Diffusion Model Code Walkthrough
Author: Sijin Yu
Preface
GitHub: https://github.com/CompVis/latent-diffusion.
Latent Diffusion is trained in two stages: the first stage trains the VAE, the second stage trains the Diffusion model. The code is organized as follows:
Stage 1: Training the AutoEncoder
The AutoencoderKL class
Location: latent-diffusion/ldm/models/autoencoder.py
This class implements a VAE-based AutoEncoder.
Methods:
- init_from_ckpt(self, path, ignore_keys=list()): loads the model state dict from the given path. (code omitted)
- encode(self, x): takes x and returns a Gaussian distribution as a DiagonalGaussianDistribution object. (see the forward pass)
- decode(self, z): takes z and returns its decoded result as a torch.tensor. (see the forward pass)
- forward(self, input, sample_posterior=True): the forward pass; encodes, then decodes. Returns the decoded torch.tensor and the encoded DiagonalGaussianDistribution. (see below)
- get_input(self, batch, k): reshapes the input batch into the expected format. (code omitted)
- training_step(self, batch, batch_idx, optimizer_idx): training step. (see below)
- validation_step(self, batch, batch_idx): validation step. (see below)
- configure_optimizers(self): builds and configures the optimizers. (code omitted)
- get_last_layer(self): returns the weights of the model's last layer. (code omitted)
- log_images(self, batch, only_inputs=False, **kwargs): logs generated images. (code omitted)
- to_rgb(self, x): logs generated segmentation maps as RGB images. (code omitted)
Constructor
def __init__(self,
             ddconfig,               # config dict used to build the Encoder and Decoder
             lossconfig,             # config dict used to build the loss function
             embed_dim,              # embedding dim
             ckpt_path=None,         # path of a pretrained checkpoint to load
             ignore_keys=[],         # keys to ignore when loading the checkpoint
             image_key="image",      # key used to pull the image tensor out of a batch
             colorize_nlabels=None,  # only relevant for segmentation inputs; otherwise unused
             monitor=None,           # metric name monitored during training
             ):
    super().__init__()
    self.image_key = image_key
    self.encoder = Encoder(**ddconfig)
    self.decoder = Decoder(**ddconfig)
    self.loss = instantiate_from_config(lossconfig)
    assert ddconfig["double_z"]
    self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)      # project the z moments into the embedding space
    self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)     # project the embedding back to z channels
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
Eight parameters are passed in:
- ddconfig: a dict with the configuration used to build the Encoder and Decoder.
- lossconfig: a dict with the configuration used to build the loss function. The underlying code is fairly involved; in short, lossconfig has two important keys:
  - 'target': a string naming the loss class to use; torch.nn.CrossEntropyLoss would be an example of a valid value.
  - 'params': a dict (possibly empty, default dict()) of keyword arguments used to construct that loss.
- embed_dim: the embedding dimension.
- The remaining parameters have little to do with the model architecture itself.
The two lines in the constructor
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
instantiate the Encoder and Decoder, whose code is covered in their own section below.
Forward pass
The forward pass is:
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
The sample_posterior flag controls whether the latent variable z is sampled from the posterior or taken as its mode (the mean).
The forward pass relies on encode and decode, whose code is:
def encode(self, x):
    h = self.encoder(x)            # 2*z_channels channels
    moments = self.quant_conv(h)   # channels: 2*z_channels -> 2*embed_dim
    posterior = DiagonalGaussianDistribution(moments)
    return posterior

def decode(self, z):
    z = self.post_quant_conv(z)    # channels: embed_dim -> z_channels
    dec = self.decoder(z)
    return dec
Note why the channel counts on the encode path are doubled: the network predicts both a mean and a log-variance for every latent channel.
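As a quick illustration (a minimal sketch, not taken from the repo; the tensor sizes are made up), this is how a doubled-channel moments tensor is split and reparameterized, which is exactly what DiagonalGaussianDistribution does internally:
import torch

# Suppose embed_dim = 4 and the latent grid is 32x32, so quant_conv outputs
# `moments` with 2*embed_dim = 8 channels: 4 for the mean, 4 for the log-variance.
moments = torch.randn(1, 8, 32, 32)

mean, logvar = torch.chunk(moments, 2, dim=1)   # each has shape (1, 4, 32, 32)
std = torch.exp(0.5 * logvar)

# reparameterized sample, which is what posterior.sample() returns
z = mean + std * torch.randn_like(mean)
print(z.shape)  # torch.Size([1, 4, 32, 32]) -> fed into decode()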
Training step
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
This is just the loss computation; the key question is what self.loss actually is. Its definition in the constructor is:
self.loss = instantiate_from_config(lossconfig)
instantiate_from_config resolves the class named by the config dict's 'target' string and instantiates it with 'params'. Looking at the config file latent-diffusion/autoencoder/autoencoder_kl_8x8x64.yaml, the loss is configured as:
lossconfig:
target: ldm.modules.losses.LPIPSWithDiscriminator
params:
disc_start: 50001
kl_weight: 0.000001
disc_weight: 0.5
So self.loss is an instance of the ldm.modules.losses.LPIPSWithDiscriminator class, which is covered in its own section below.
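For orientation, instantiate_from_config (defined in ldm/util.py) is roughly equivalent to the following sketch; this is a simplification, the real helper has a few extra special cases:
import importlib

def instantiate_from_config(config):
    # resolve the dotted class path in config["target"] and instantiate it
    # with the keyword arguments in config["params"] (empty dict if absent)
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", dict()))

# With the yaml above this amounts to roughly:
# LPIPSWithDiscriminator(disc_start=50001, kl_weight=0.000001, disc_weight=0.5)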
Validation step
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val")
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val")
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
The Encoder and Decoder classes
Location: latent-diffusion/ldm/modules/diffusionmodules/model.py
The Encoder and Decoder are straightforward, fairly classical convolutional networks, so little commentary is needed; the code follows.
Encoder
class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
**ignore_kwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
Decoder
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
attn_type="vanilla", **ignorekwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,)+tuple(ch_mult)
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,
padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, z):
#assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
return h
The code above uses the make_attn(in_channels, attn_type="vanilla") helper, shown below.
Attention
def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
return AttnBlock(in_channels)
elif attn_type == "none":
return nn.Identity(in_channels)
else:
return LinAttnBlock(in_channels)
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b,c,h*w)
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b,c,h,w)
h_ = self.proj_out(h_)
return x+h_
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
return self.to_out(out)
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
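To make the shapes concrete, here is a small sanity-check sketch with a hypothetical ddconfig (the values are illustrative and not taken from any yaml in the repo): with len(ch_mult) = 3 the Encoder downsamples twice, so a 64x64 input yields a 16x16 latent grid with 2*z_channels channels when double_z=True.
import torch
from ldm.modules.diffusionmodules.model import Encoder, Decoder

ddconfig = dict(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
                attn_resolutions=[], in_channels=3, resolution=64,
                z_channels=4, double_z=True, dropout=0.0)

enc, dec = Encoder(**ddconfig), Decoder(**ddconfig)
x = torch.randn(1, 3, 64, 64)
h = enc(x)               # (1, 2*z_channels, 16, 16) = (1, 8, 16, 16)
x_rec = dec(h[:, :4])    # the Decoder expects z_channels input channels -> (1, 3, 64, 64)
print(h.shape, x_rec.shape)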
The DiagonalGaussianDistribution class
Location: latent-diffusion/ldm/modules/distributions/distributions.py
This class represents a diagonal Gaussian distribution.
It exposes four methods:
- sample(self): draws a random sample from the distribution. (see below)
- kl(self, other=None): computes the KL divergence to another Gaussian (the standard Gaussian by default). (see below)
- nll(self, sample, dims=[1, 2, 3]): computes the negative log-likelihood of a given sample. (see below)
- mode(self): returns the mean. (code omitted)
Constructor
def __init__(self, parameters, deterministic=False):
    self.parameters = parameters
    # split `parameters` into two halves along dim=1: the mean and the log-variance
    self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
    self.logvar = torch.clamp(self.logvar, -30.0, 20.0)  # clamp the log-variance to (-30.0, 20.0)
    self.deterministic = deterministic
    self.std = torch.exp(0.5 * self.logvar)  # log-variance -> standard deviation
    self.var = torch.exp(self.logvar)        # log-variance -> variance
    if self.deterministic:
        self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
Two arguments are passed in:
- parameters: a torch.tensor holding the mean and log-variance.
- deterministic=False: whether the distribution is deterministic. If True, the standard deviation and variance are set to 0 and the distribution collapses to its mean.
(Back to the DiagonalGaussianDistribution class.)
Sampling
def sample(self):
    # draw one random sample from this distribution (reparameterization trick)
    x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
    return x
(Back to the DiagonalGaussianDistribution class.)
KL divergence
The in-class KL method is shown below. The other argument is another Gaussian object; when it is None, the KL divergence to the standard Gaussian is computed.
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
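For reference, the closed form being summed here (per element, over dims [1, 2, 3]) is the standard KL divergence between two diagonal Gaussians:
$$D_{\mathrm{KL}}\!\left(\mathcal N(\mu_1,\sigma_1^2)\,\|\,\mathcal N(\mu_2,\sigma_2^2)\right)=\frac{1}{2}\left(\frac{(\mu_1-\mu_2)^2}{\sigma_2^2}+\frac{\sigma_1^2}{\sigma_2^2}-1-\log\sigma_1^2+\log\sigma_2^2\right),$$
with $\mu_2=0$ and $\sigma_2^2=1$ in the `other is None` branch.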
Given means and log-variances directly, the KL divergence between two Gaussians can also be computed with the standalone helper below:
def normal_kl(mean1, logvar1, mean2, logvar2):
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
(Back to the DiagonalGaussianDistribution class.)
Negative log-likelihood
def nll(self, sample, dims=[1,2,3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
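The expression being summed is just the per-element Gaussian negative log-likelihood:
$$-\log \mathcal N(x;\mu,\sigma^2)=\frac{1}{2}\left(\log 2\pi+\log\sigma^2+\frac{(x-\mu)^2}{\sigma^2}\right).$$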
(Back to the DiagonalGaussianDistribution class.)
The LPIPSWithDiscriminator class
Location: latent-diffusion/ldm/modules/losses/contperceptual.py
This class computes the VAE loss, which has four parts: (1) a pixel-level L1 loss between the real and reconstructed images, (2) a feature-level perceptual similarity loss between them, (3) the VAE's KL loss, and (4) the generator/discriminator (GAN) loss.
It has two methods:
- calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): computes an adaptive weight that balances the reconstruction loss against the generator/discriminator loss. (code omitted)
- forward(self, inputs, reconstructions, posteriors, optimizer_idx, global_step, last_layer=None, cond=None, split="train", weights=None): the forward pass, which computes the loss. Key arguments:
  - inputs: the real input images.
  - reconstructions: the images reconstructed by the VAE.
  - posteriors: the distribution (mean and variance) predicted by the VAE bottleneck.
  - optimizer_idx: an indicator; 0 optimizes the generator, 1 optimizes the discriminator.
(See below.)
Constructor
def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
disc_loss="hinge"):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
self.kl_weight = kl_weight
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm
).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
Arguments:
- disc_start: the global step at which the discriminator loss starts being applied; it affects the weight of the GAN loss.
- logvar_init: initial value of the log-variance used to balance the reconstruction and regularization losses. (discussed in detail below)
- kl_weight: weight of the KL loss (the KL divergence between the VAE's predicted Gaussian and the standard Gaussian, which acts as the VAE's regularizer).
- pixelloss_weight: weight of the pixel loss (the L1 loss between real and generated images). Note this parameter is never actually used in the code.
- disc_weight: weight of the generator/discriminator loss (the discriminator tries to tell real from generated images; the generator tries to fool the discriminator). Together with disc_start it determines the weight of the GAN loss.
- perceptual_weight: weight of the perceptual similarity loss (like the pixel loss it pulls the reconstruction towards the input, but it compares VGG feature maps of the two images rather than raw pixels).
- disc_num_layers: number of layers in the discriminator.
- disc_in_channels: number of input channels of the discriminator.
- disc_factor: a factor scaling the GAN loss. Together with disc_start and disc_weight it determines the final GAN-loss weight.
- use_actnorm: whether to use activation normalization (ActNorm) in the discriminator.
- disc_conditional: whether the discriminator is conditional.
- disc_loss: the type of discriminator loss.
The line in the constructor
self.perceptual_loss = LPIPS().eval()
uses the LPIPS class, which measures the perceptual similarity between two images; it is covered in its own section below.
(Back to the LPIPSWithDiscriminator class.)
Forward pass
def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
            global_step, last_layer=None, cond=None, split="train",
            weights=None):
    # rec_loss: pixel-wise L1 distance between the input and the reconstruction
    rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
    if self.perceptual_weight > 0:
        # p_loss: LPIPS loss, computed from the similarity of the two images' VGG features at several layers
        p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
        # scale it by self.perceptual_weight to balance the loss terms:
        # reconstruction loss = L1 distance + w * LPIPS loss
        rec_loss = rec_loss + self.perceptual_weight * p_loss

    # negative log-likelihood under a learned global log-variance (explained below)
    nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
    weighted_nll_loss = nll_loss
    if weights is not None:
        weighted_nll_loss = weights*nll_loss
    weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
    nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
    # KL divergence between the posterior and the standard Gaussian
    kl_loss = posteriors.kl()
    kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

    # the losses below train the GAN part
    # optimizer_idx takes two values: 0 updates the generator, 1 updates the discriminator
    if optimizer_idx == 0:
        # update the generator
        if cond is None:  # cond indicates whether the discriminator is conditional
            assert not self.disc_conditional  # unconditional discrimination
            logits_fake = self.discriminator(reconstructions.contiguous())
        else:
            assert self.disc_conditional  # conditional discrimination
            logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
        # logits_fake is the discriminator output on the reconstructions (fake data);
        # since we are training the generator here, the goal is to fool the discriminator
        # into giving the reconstructions high, real-looking logits
        g_loss = -torch.mean(logits_fake)  # generator loss

        # below, the generator loss gets an adaptive weight
        # when disc_factor <= 0.0 the generator loss is disabled
        # the GAN part is only used while training the VAE, not while training the Diffusion model
        if self.disc_factor > 0.0:
            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                assert not self.training
                d_weight = torch.tensor(0.0)
        else:
            d_weight = torch.tensor(0.0)

        disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
        # loss = reconstruction loss (weighted_nll_loss) + KL regularizer (kl_loss) + generator loss (g_loss)
        loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

        log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
               "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
               "{}/rec_loss".format(split): rec_loss.detach().mean(),
               "{}/d_weight".format(split): d_weight.detach(),
               "{}/disc_factor".format(split): torch.tensor(disc_factor),
               "{}/g_loss".format(split): g_loss.detach().mean(),
               }
        return loss, log

    if optimizer_idx == 1:
        # update the discriminator
        if cond is None:  # as above: conditional or unconditional discrimination
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())
        else:
            logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
            logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

        # as above, warm-up factor for the discriminator
        disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
        # self.disc_loss defines how the discriminator is trained (explained below)
        d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

        log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
               "{}/logits_real".format(split): logits_real.detach().mean(),
               "{}/logits_fake".format(split): logits_fake.detach().mean()
               }
        return d_loss, log
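Two helpers called above are not listed in this article. The sketch below shows their idea as I understand it from the repo (treat the exact signatures as approximate): adopt_weight simply zeroes the GAN term before disc_start steps, and calculate_adaptive_weight rescales the generator loss so that its gradient at the decoder's last layer has roughly the same norm as the reconstruction gradient.
import torch

def adopt_weight(weight, global_step, threshold=0, value=0.):
    # keep the GAN term switched off until `threshold` (= disc_start) steps have passed
    return value if global_step < threshold else weight

# method of LPIPSWithDiscriminator (sketch):
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
    # balance the generator gradient against the reconstruction gradient
    # at the decoder's last layer, then scale by disc_weight
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
    return d_weight * self.discriminator_weight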
The discriminator loss
What is self.disc_loss? Its assignment in the constructor is:
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
So there are two possible discriminator losses, both very simple:
def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss

def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real)) +
        torch.mean(torch.nn.functional.softplus(logits_fake)))
    return d_loss
Why nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar?
First, look at the declaration of self.logvar in the constructor:
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
It is a single learnable scalar. Dividing the reconstruction error rec_loss by exp(self.logvar) turns it into a (Gaussian-style) negative log-likelihood and lets the model estimate how uncertain its reconstructions are: where reconstruction is hard, the model can increase self.logvar to reduce the influence of the reconstruction error, which makes training more robust to noisy data.
A natural follow-up question: if that is the goal, why not simply use
nll_loss = rec_loss / torch.exp(self.logvar)
i.e., why add self.logvar at the end? The reason is that we do not want the model to inflate the uncertainty without limit. Without the + self.logvar term, the model could push self.logvar towards infinity, declare reconstruction "always hard", drive nll_loss towards 0 and optimize only the regularizer. That is clearly undesirable, so the log-variance term is added to penalize this escape route and force a trade-off between the two behaviours.
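A one-line way to see this: for a fixed reconstruction error $r$, the function $f(s)=r\,e^{-s}+s$ is minimized where
$$\frac{\partial}{\partial s}\left(r\,e^{-s}+s\right)=-r\,e^{-s}+1=0\ \Longrightarrow\ s^{*}=\log r,$$
so the learned log-variance tracks the typical reconstruction error instead of running off to infinity, and the loss at the optimum, $1+\log r$, still decreases as the reconstruction improves.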
(Back to the LPIPSWithDiscriminator class.)
The LPIPS class
Location: taming/modules/losses/lpips.py
LPIPS stands for Learned Perceptual Image Patch Similarity. The class inherits torch.nn.Module and measures how perceptually similar two images are.
Its main methods are:
- load_from_pretrained(self, name="vgg_lpips"): loads pretrained weights. (code omitted)
- from_pretrained(cls, name="vgg_lpips"): class method, also loads pretrained weights. (code omitted)
- forward(self, input, target): the forward pass; input is the original image and target the generated one, and it returns their multi-scale similarity. (see below)
(Back to the LPIPSWithDiscriminator class.)
Constructor
def __init__(self, use_dropout=True):
super().__init__()
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512] # vg16 features
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False
Forward pass
def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val
Stage 2: Training the Latent Diffusion model
The LatentDiffusion class
Location: latent-diffusion/ldm/models/diffusion/ddpm.py
LatentDiffusion inherits from the classic pixel-space DDPM class, and it is strongly recommended to read the DDPM code first (covered below).
The LatentDiffusion class has the following methods:
- __init__: constructor.
- register_schedule: registers the noise schedule by calling register_schedule of the DDPM class. (code omitted)
- make_cond_schedule: called inside register_schedule above; specifies at which diffusion timesteps the conditioning input is applied. (code omitted)
- on_train_batch_start: decorated with rank_zero_only and torch.no_grad(). It fires only on the first batch of training and sets a standardizing rescale factor self.scale_factor for the latent space, which benefits training stability and performance. (code omitted)
- instantiate_first_stage: instantiates the first-stage model (the AutoEncoder) and freezes its parameters. (code omitted)
- instantiate_cond_stage: instantiates the conditioning encoder (e.g. a CLIP text encoder) and freezes its parameters. (code omitted)
- _get_denoise_row_from_list: renders the provided samples into an image grid for visualization. (code omitted)
- get_first_stage_encoding: obtains the latent variable z from the first-stage Encoder. (code omitted)
- get_learned_conditioning: obtains the conditioning embedding from the conditioning encoder. (code omitted)
- meshgrid: builds a grid-coordinate tensor. Given the image height h and width w, it returns a torch.tensor of shape [h, w, 2] holding each pixel's y and x coordinates. (code omitted)
- delta_border: computes each pixel's normalized distance to the image border. Given h and w, it returns a torch.tensor of shape [h, w, 1]. (code omitted)
- get_weighting: computes a per-region weight over the image; central regions receive larger weights and border regions smaller ones, based on the normalized distance to the border. (code omitted)
- get_fold_unfold: splits the image into patches and re-weights pixel values according to the per-region weights. (code omitted)
- get_input: decorated with torch.no_grad(). Processes a batch into the final inputs, i.e. the image x and the conditioning c.
- decode_first_stage: decodes the latent representation z. (code omitted)
- differentiable_decode_first_stage: a differentiable version of decode_first_stage, i.e. gradients can flow through it. (code omitted)
- encode_first_stage: decorated with torch.no_grad(). Encodes an image into z. (code omitted)
- shared_step: shares the timestep within a batch and runs Latent Diffusion. (code omitted)
- forward: samples timesteps, runs the denoising objective and returns the reconstruction loss. (code omitted)
- _rescale_annotations: rescales bounding-box coordinates in the image. (code omitted)
- apply_model: splits the noisy image x_noise into patches, applies the model to each patch and stitches the outputs back into a full image. (code omitted)
- _predict_eps_from_xstart: never called. (code omitted)
- _prior_bpd: computes the KL divergence between the distribution at the final diffusion step and the standard Gaussian. This KL term depends only on the encoder, cannot be changed by optimization, and measures how well the model can represent the input data. Never called. (code omitted)
- p_losses: same role as p_losses in DDPM. (code omitted)
- p_mean_variance: same role as p_mean_variance in DDPM. (code omitted)
- p_sample: same role as p_sample in DDPM. (code omitted)
- p_sample_loop: same role as p_sample_loop in DDPM. (code omitted)
- progressive_denoising: samples and produces the final image. (code omitted)
- sample: same role as sample in DDPM. (code omitted)
The DDPM class
Location: latent-diffusion/ldm/models/diffusion/ddpm.py
Methods (arguments omitted):
- __init__: constructor. (see below)
- register_schedule: computes the DDPM quantities $\beta_t$, $\alpha_t$, etc., as well as the distribution parameters used throughout the diffusion process. (see below)
- ema_scope: uses the @contextmanager decorator; temporarily switches to the exponential-moving-average weights during training. (code omitted)
- init_from_ckpt: loads the model from a given checkpoint. (code omitted)
- q_mean_variance: computes the diffusion conditional $q(x_t|x_0)$ and returns its mean, variance and log-variance. (see below)
- predict_start_from_noise: given the noisy image $x_t$, the timestep $t$ and the predicted noise $\hat\epsilon$, computes the predicted denoised image $\hat x_0$. (see below)
- q_posterior: computes the posterior $q(x_{t-1}|x_t, x_0)$ and returns its mean, variance and log-variance. (see below)
- p_mean_variance: computes the reverse distribution $p(x_{t-1}|x_t)$ and returns its mean, variance and log-variance. (see below)
- p_sample: reverse-process sampling; given $x_t$, samples $x_{t-1}$ from $p(x_{t-1}|x_t)$. (no gradients) (see below)
- p_sample_loop: reverse-process sampling; iterates $p(x_{t-1}|x_t)$ from $x_T$ all the way down to $x_0$. (no gradients) (see below)
- sample: reverse-process sampling; iterates $p(x_{t-1}|x_t)$ from $x_T$ down to $x_0$. (no gradients) (see below)
- q_sample: forward-process sampling; samples $x_t$ from $q(x_t|x_0)$. (see below)
- get_loss: computes the loss between the noise predicted by the UNet and the true noise. (see below)
- p_losses: computes the noise-prediction loss including the variational-lower-bound (VLB) term. (see below)
- forward: forward pass; takes the original image $x_0$ and returns the diffusion loss. (see below)
- get_input: reshapes the input image into the expected format. (code omitted)
- shared_step: reads a batch of input images, runs the forward pass to get the loss, and lets all samples in the batch share the timestep. (code omitted)
- training_step: training. (see below)
- validation_step: validation. (see below)
- on_train_batch_end: of little interest and not called in this code; can be ignored.
- _get_rows_from_list: reshapes some samples for logging. (code omitted)
- log_images: logs generated images. (no gradients) (code omitted)
- configure_optimizers: configures the optimizer, torch.optim.AdamW. (code omitted)
Constructor
def __init__(self,
unet_config,
timesteps=1000,
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
load_only_unet=False,
monitor="val/loss",
use_ema=True,
first_stage_key="image",
image_size=256,
channels=3,
log_every_t=100,
clip_denoised=True,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
given_betas=None,
original_elbo_weight=0.,
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
l_simple_weight=1.,
conditioning_key=None,
parameterization="eps", # all assuming fixed variance schedules
scheduler_config=None,
use_positional_encodings=False,
learn_logvar=False,
logvar_init=0.,
):
super().__init__()
assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
self.parameterization = parameterization
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
self.cond_stage_model = None
self.clip_denoised = clip_denoised
self.log_every_t = log_every_t
self.first_stage_key = first_stage_key
self.image_size = image_size # try conv?
self.channels = channels
self.use_positional_encodings = use_positional_encodings
self.model = DiffusionWrapper(unet_config, conditioning_key)
count_params(self.model, verbose=True)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
self.scheduler_config = scheduler_config
self.v_posterior = v_posterior
self.original_elbo_weight = original_elbo_weight
self.l_simple_weight = l_simple_weight
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
self.loss_type = loss_type
self.learn_logvar = learn_logvar
self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
if self.learn_logvar:
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
Arguments:
- unet_config: dict, the UNet configuration.
- timesteps: total number of diffusion timesteps. (default 1000)
- beta_schedule: schedule for the noise level $\beta_t$. (default 'linear', i.e. $\beta_t$ grows linearly)
- loss_type: loss used for the noise-prediction error. (default 'l2', mean squared error)
- ckpt_path: path of a checkpoint file to load. (default None, no loading)
- ignore_keys: list of keys to ignore when loading the checkpoint. (default [], ignore nothing)
- load_only_unet: whether to load only the UNet weights. (default False)
- monitor: metric monitored during training to judge model quality. (default 'val/loss', the validation loss)
- use_ema: whether to smooth the model parameters with an exponential moving average (EMA). (default True)
- first_stage_key: key of the input used by the first-stage model. (default 'image')
- image_size: image size. (default 256, i.e. $256\times256$ images)
- channels: number of image channels. (default 3)
- log_every_t: during generation, log an image every this many timesteps. (default 100)
- clip_denoised: whether to clamp the denoised output to the range $(-1.0, 1.0)$. (default True)
- linear_start: the value of $\beta_0$. (default 1e-4)
- linear_end: the value of $\beta_T$. (default 2e-2)
- cosine_s: only relevant for the cosine schedule; controls its shape. (default 8e-3)
- given_betas: directly supply a sequence $[\beta_t]_{t=0}^T$. (default None)
- original_elbo_weight: weight of the original evidence-lower-bound (ELBO) term in the loss. (default 0.)
- v_posterior: weight $v$ used to choose the posterior variance, $\sigma_t^2=(1-v)\tilde\beta_t+v\beta_t$. (default 0.)
- l_simple_weight: weight of the simple loss. (default 1.)
- conditioning_key: key of the conditioning data when doing conditional generation. (default None)
- parameterization: how the model is parameterized, i.e. whether the UNet predicts the original image or the noise. (default 'eps')
- scheduler_config: dict, configuration of the LR scheduler. (default None)
- use_positional_encodings: whether to use positional encodings. (default False)
- learn_logvar: whether the log-variance is a learned parameter. (default False)
- logvar_init: initial value of the log-variance. (default 0.)
One line of the constructor is:
self.model = DiffusionWrapper(unet_config, conditioning_key)
This is essentially the UNet. DiffusionWrapper is what implements conditional vs. unconditional diffusion; its code is:
class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()
        return out
Naturally, this brings us back to instantiate_from_config, so we check the config file. In latent-diffusion/config/latent-diffusion/celebahq-ldm-vq-4.yaml we find:
unet_config:
  target: ldm.modules.diffusionmodules.openaimodel.UNetModel
So we go look at the UNetModel class. (see below)
Another line of the constructor is:
if self.use_ema:
    self.model_ema = LitEma(self.model)
which smooths the model parameters with an exponential moving average (EMA). (code omitted)
Registering the $\beta$ and $\alpha$ schedules
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if exists(given_betas):
        betas = given_betas  # use the supplied betas
    else:
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)  # otherwise compute them from the chosen schedule
    alphas = 1. - betas  # by definition, \alpha_t = 1 - \beta_t
    alphas_cumprod = np.cumprod(alphas, axis=0)  # \bar\alpha_t = \prod_{s<=t} \alpha_s
    # np.cumprod returns an array whose i-th entry is the product of all entries up to index i
    alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])  # drop the last entry and prepend 1, giving \bar\alpha_{t-1}
    timesteps, = betas.shape  # total number of timesteps
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
    to_torch = partial(torch.tensor, dtype=torch.float32)
    # to_torch(x) is equivalent to torch.tensor(x, dtype=torch.float32)

    # the buffers below are precomputed intermediate quantities used directly by the other methods
    # register_buffer is a torch.nn.Module method that adds a tensor to the module's buffers;
    # buffers are not treated as model parameters and receive no gradient updates
    # the DDPM \beta_t array:
    self.register_buffer('betas', to_torch(betas))
    # the DDPM \bar\alpha_t array:
    self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
    # kept so that formulas involving the previous timestep are easy to evaluate:
    self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
    # alphas_cumprod_prev[t] is the fraction of the original signal remaining after timestep t-1

    # calculations for diffusion q(x_t | x_{t-1}) and others
    # \sqrt{\bar\alpha_t}:
    self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
    # \sqrt{1-\bar\alpha_t}:
    self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
    # \log(1-\bar\alpha_t):
    self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
    # \sqrt{1/\bar\alpha_t}:
    self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
    # \sqrt{1/\bar\alpha_t - 1}:
    self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
1. - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
self.register_buffer('posterior_mean_coef1', to_torch(
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
self.register_buffer('posterior_mean_coef2', to_torch(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
if self.parameterization == "eps":
lvlb_weights = self.betas ** 2 / (
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
elif self.parameterization == "x0":
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
else:
raise NotImplementedError("mu not supported")
# TODO how to choose this term
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
assert not torch.isnan(self.lvlb_weights).all()
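The make_beta_schedule helper lives in ldm/modules/diffusionmodules/util.py; for the default "linear" schedule it is, as far as I recall, essentially the following (a sketch under that assumption, with a renamed function, not a verbatim copy):
import torch

def make_beta_schedule_linear(n_timestep, linear_start=1e-4, linear_end=2e-2):
    # "linear" here means linear in sqrt(beta): interpolate between sqrt(start)
    # and sqrt(end), then square, giving betas that grow from 1e-4 to 2e-2
    betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5,
                           n_timestep, dtype=torch.float64) ** 2
    return betas.numpy()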
Quick links: the DDPM class | forward (noising) process | predicting the original image | posterior of the forward (noising) process | loss of the reverse (denoising) process.
Forward (noising) process
def q_mean_variance(self, x_start, t):
    # compute q(x_t | x_0)
mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
This function evaluates the DDPM formula
$$q(x_t|x_0)=\mathcal N\!\left(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)\mathbf I\right)$$
Inputs:
- x_start: the original image $x_0$.
- t: the timestep $t$.
Returns:
- mean: the mean of $q(x_t|x_0)$.
- variance: the variance of $q(x_t|x_0)$.
- log_variance: the log-variance of $q(x_t|x_0)$.
Here (all of these buffers are defined in register_schedule above):
- self.sqrt_alphas_cumprod is the $\sqrt{\bar\alpha_t}$ array.
- self.alphas_cumprod is the $\bar\alpha_t$ array.
- self.log_one_minus_alphas_cumprod is the $\log(1-\bar\alpha_t)$ array.
extract_into_tensor(a, t, x_shape) picks the t-th entry of the array a for every sample in the batch and reshapes it into a torch.tensor broadcastable against x_shape.
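A sketch of that helper (my paraphrase of the repo's extract_into_tensor, which lives in ldm/modules/diffusionmodules/util.py):
def extract_into_tensor(a, t, x_shape):
    # a: 1-D tensor of per-timestep constants; t: (batch,) tensor of timestep indices
    b = t.shape[0]
    out = a.gather(-1, t)                                  # pick a[t_i] for every sample i
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))    # e.g. shape (b, 1, 1, 1) for images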
Predicting the original image
def predict_start_from_noise(self, x_t, t, noise):
return (
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
This function evaluates the DDPM formula
$$\hat x_0=\sqrt{\tfrac{1}{\bar\alpha_t}}\,x_t-\sqrt{\tfrac{1}{\bar\alpha_t}-1}\cdot\hat\epsilon$$
Inputs:
- x_t: the noisy image $x_t$.
- t: the timestep $t$.
- noise: the predicted noise $\hat\epsilon$.
Output:
- the predicted original image $\hat x_0$.
Here (defined in register_schedule above):
- self.sqrt_recip_alphas_cumprod is the $\sqrt{1/\bar\alpha_t}$ array.
- self.sqrt_recipm1_alphas_cumprod is the $\sqrt{1/\bar\alpha_t-1}$ array.
As before, extract_into_tensor(a, t, x_shape) picks the t-th entry of a and reshapes it to broadcast against x_shape.
Posterior of the forward (noising) process
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
This function evaluates the DDPM formula
$$q(x_{t-1}|x_t, x_0)=\mathcal N\!\left(x_{t-1};\ \tilde\mu_t(x_t,x_0),\ \tilde\beta_t\mathbf I\right)$$
where
$$\tilde \mu_t(x_t, x_0):=\frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}x_0+\frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}x_t,\qquad \tilde\beta_t:=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t$$
Inputs:
- x_start: the original image $x_0$.
- x_t: the noisy image $x_t$.
- t: the timestep $t$.
Outputs:
- posterior_mean: the mean of the posterior $q(x_{t-1}|x_t, x_0)$.
- posterior_variance: its variance.
- posterior_log_variance_clipped: its (clipped) log-variance.
Here (defined in register_schedule above):
- self.posterior_mean_coef1 is the $\frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}$ array.
- self.posterior_mean_coef2 is the $\frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}$ array.
- self.posterior_variance is the $\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t$ array.
- self.posterior_log_variance_clipped is the $\log\left(\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t\right)$ array (clipped away from zero).
As before, extract_into_tensor(a, t, x_shape) picks the t-th entry of a and reshapes it to broadcast against x_shape.
Quick links: back to the DDPM class | reverse (denoising) process.
Reverse (denoising) process
def p_mean_variance(self, x, t, clip_denoised: bool):
model_out = self.model(x, t)
    if self.parameterization == "eps":
        # the model predicts the noise
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        # the model predicts the denoised image directly
        x_recon = model_out

    if clip_denoised:
        # optionally clamp the image values to the range (-1.0, 1.0)
        x_recon.clamp_(-1., 1.)

    # compute the posterior q(x_{t-1} | x_t, \hat x_0) and use it as an estimate of p_\theta(x_{t-1}|x_t)
    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
This function evaluates the DDPM formula
$$p_{\theta}(x_{t-1}|x_t):=\mathcal N\!\left(x_{t-1};\ \mu_{\theta}(x_t, t),\ \Sigma_{\theta}(x_t, t)\right)$$
Inputs:
- x: the noisy image $x_t$ at the current timestep.
- t: the timestep $t$.
- clip_denoised: whether to clamp the image values to $(-1.0, 1.0)$.
Outputs:
- model_mean: the predicted mean $\mu_{\theta}(x_t, t)$.
- posterior_variance: the predicted variance $\Sigma_{\theta}(x_t, t)$.
- posterior_log_variance: the predicted log-variance $\log\Sigma_{\theta}(x_t, t)$.
Here self.q_posterior approximates the reverse distribution with the forward-process posterior; its definition is given above.
Quick link: back to the DDPM class.
Sampling images
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
    noise = noise_like(x.shape, device, repeat_noise)  # standard Gaussian noise with the same shape as x
    # repeat_noise: if True, every sample in the batch gets the same noise; otherwise each sample gets independent noise
    # nonzero_mask: 0 at t=0 (no noise is added at the last step), 1 otherwise
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
p_sample draws one sample $x_{t-1}$ from the distribution $p_{\theta}(x_{t-1}|x_t)$.
Inputs:
- x: the batch of samples $x_t$ at the current timestep.
- t: the current timestep $t$.
- clip_denoised: whether to clamp the denoised values to $(-1.0, 1.0)$.
- repeat_noise: whether to reuse the same noise for every sample in the batch.
Output:
- a batch of samples $x_{t-1}$ at the next (earlier) timestep, with the same shape as x.
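In formula form, the returned sample is
$$x_{t-1}=\mu_\theta(x_t,t)+\mathbb 1[t>0]\cdot\exp\!\left(\tfrac12\log\Sigma_\theta(x_t,t)\right)\cdot\epsilon,\qquad \epsilon\sim\mathcal N(0,\mathbf I),$$
where the indicator is the nonzero_mask above: no noise is added at the final step $t=0$.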
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
intermediates = [img]
    for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
        # i runs from T-1 down to 0
        img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                            clip_denoised=self.clip_denoised)
        # img is now x_i
        if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
            intermediates.append(img)  # keep intermediate images
if return_intermediates:
return img, intermediates
return img # x_0
Inputs:
- shape: the shape of the images to generate.
- return_intermediates: whether to also return the intermediate images of the reverse process.
Outputs:
- img: the generated image $x_0$.
- intermediates: a list of intermediate images.
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size),
return_intermediates=return_intermediates)
Given a batch_size, sample directly generates a batch of samples $x_0$.
Simulating the forward (noising) process
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
    # if noise is not None, use it as-is; otherwise draw standard Gaussian noise shaped like x_start
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
This function samples $x_t$ from the DDPM forward-process distribution
$$q(x_t|x_0)=\mathcal N\!\left(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)\mathbf I\right)$$
Inputs:
- x_start: the original image $x_0$.
- t: the timestep $t$.
- noise: the noise; if None, a standard Gaussian sample shaped like x_start is drawn.
Output:
- a sample $x_t$ with the same shape as x_start.
Noise-prediction loss
def get_loss(self, pred, target, mean=True):
if self.loss_type == 'l1':
loss = (target - pred).abs()
if mean:
loss = loss.mean()
elif self.loss_type == 'l2':
if mean:
loss = torch.nn.functional.mse_loss(target, pred)
else:
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
else:
raise NotImplementedError("unknown loss type '{loss_type}'")
return loss
This function computes the noise-prediction loss; the code is self-explanatory.
Inputs:
- pred: the predicted noise.
- target: the true noise.
- mean: whether to reduce the loss to a scalar mean.
Output:
- loss: the loss.
Loss of the reverse (denoising) process
def p_losses(self, x_start, t, noise=None):
    # noise: the noise added at the current timestep t
    noise = default(noise, lambda: torch.randn_like(x_start))
    # x_noisy: a sample obtained by running the noising process from x_0
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    # model_out: the UNet's prediction of the noise added at timestep t
    model_out = self.model(x_noisy, t)
loss_dict = {}
if self.parameterization == "eps":
target = noise
elif self.parameterization == "x0":
target = x_start
else:
raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
    # loss on the predicted noise (or predicted image)
    loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

    log_prefix = 'train' if self.training else 'val'
    loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
    # simple loss = loss * weight
    loss_simple = loss.mean() * self.l_simple_weight
    # variational-lower-bound (VLB) loss = per-timestep weight * loss
    loss_vlb = (self.lvlb_weights[t] * loss).mean()
    loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
    # total loss = simple loss + original ELBO weight * VLB loss
    loss = loss_simple + self.original_elbo_weight * loss_vlb
loss_dict.update({f'{log_prefix}/loss': loss})
return loss, loss_dict
This function takes $x_0$ and $t$ as inputs, first runs the diffusion process $q(x_t|x_0)$ to sample a noisy image $x_t$, then uses the UNet to predict the noise added at timestep $t$ and computes the prediction loss, following the DDPM objective
$$L=\mathbb E_{x_0\sim q(x_0),\,\epsilon\sim\mathcal N(0, \mathbf I)}\left[\frac{\beta_t^2}{2\tilde\beta_t\alpha_t(1-\bar\alpha_t)}\left\lVert\epsilon-\epsilon_{\theta}\!\left(\sqrt{\bar\alpha_t}\,x_0+\sqrt{1-\bar\alpha_t}\,\epsilon,\ t\right)\right\rVert^2\right]$$
Inputs:
- x_start: the original image $x_0$.
- t: the timestep $t$.
- noise: the noise sample.
Outputs:
- loss: the total loss.
- loss_dict: a dict recording the individual losses (for logging).
Here self.lvlb_weights is the $\frac{\beta_t^2}{2\tilde\beta_t\alpha_t(1-\bar\alpha_t)}$ array, defined in register_schedule above.
Forward pass
def forward(self, x, *args, **kwargs):
# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
    # draw a random timestep uniformly from [0, T) for every sample in the batch
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
    # return the noise-prediction loss
return self.p_losses(x, t, *args, **kwargs)
Input: x, the original image.
Outputs: the loss and the loss dict.
Training
def training_step(self, batch, batch_idx):
loss, loss_dict = self.shared_step(batch)
self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True)
self.log("global_step", self.global_step,
prog_bar=True, logger=True, on_step=True, on_epoch=False)
if self.use_scheduler:
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
return loss
The training code is straightforward and needs no further explanation.
Validation
@torch.no_grad()
def validation_step(self, batch, batch_idx):
_, loss_dict_no_ema = self.shared_step(batch)
with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
The validation code is likewise straightforward.
The UNetModel class
Location: latent-diffusion/modules/diffusionmodules/openaimodel.py
This class implements the UNet. It essentially has only two methods, __init__ and forward. The whole architecture is laid out in the constructor, so for simplicity we look at the forward pass first.
Forward pass
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via crossattn
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
    hs = []  # stores the per-level feature maps for the UNet skip connections
    # compute the timestep embedding
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
        # add the class embedding
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
    for module in self.input_blocks:  # UNet downsampling path
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
    for module in self.output_blocks:  # UNet upsampling path
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
The code is quite direct; a quick summary:
Inputs:
- x: the UNet's image input $x_t$.
- timesteps: the timestep $t$.
- context: the conditioning used for cross-attention.
- y: the class condition (label) of the image.
Output:
- the UNet output. In diffusion this can be either the prediction of the original image or the prediction of the noise.
A few important pieces in the code:
- the timestep_embedding function: produces a sinusoidal time embedding from the given timesteps. (see below)
- self.time_embed: maps the sinusoidal time embedding through a small MLP so the model can learn its own embedding. Its definition is:
self.time_embed = nn.Sequential(
    nn.Linear(model_channels, time_embed_dim),
    nn.SiLU(),
    nn.Linear(time_embed_dim, time_embed_dim),
)
- self.label_emb: maps the class label y to a label embedding:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
The following three all use the TimestepEmbedSequential class (see below):
- self.input_blocks: the UNet downsampling path. (see below)
- self.middle_block: the UNet middle block. (see below)
- self.output_blocks: the UNet upsampling path. (see below)
Downsampling
The following is a fragment of the constructor.
self._feature_size = model_channels
input_block_chans = [model_channels]  # record the channel count of every downsampling level
ch = model_channels
ds = 1
# channel_mult gives the channel multiplier of each downsampling level
for level, mult in enumerate(channel_mult):
    # each downsampling level contains num_res_blocks ResBlocks
    for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
        # does this resolution need attention?
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
    # is this the last downsampling level?
    if level != len(channel_mult) - 1:
        # it is not, so downsample
        out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
                if resblock_updown  # resblock_updown: whether to use a ResBlock for up/downsampling
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
        ds *= 2  # update the resolution index
self._feature_size += ch
This involves several classes: ResBlock, AttentionBlock / SpatialTransformer, TimestepEmbedSequential and Downsample (covered below where relevant).
Middle block
# middle block: ResBlock -> AttentionBlock (or SpatialTransformer) -> ResBlock
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
This involves the same classes as the downsampling path: ResBlock, AttentionBlock / SpatialTransformer and TimestepEmbedSequential.
Upsampling
The upsampling code mirrors the downsampling code (it is essentially the reverse), so it is shown without further commentary:
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
    # iterate over the channel multipliers in reverse
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
The timestep_embedding function
Location: latent-diffusion/modules/diffusionmodules/util.py
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
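A quick shape check (assuming the function above is importable; the dim value is arbitrary):
import torch

t = torch.tensor([1, 10, 999])
emb = timestep_embedding(t, dim=128)
print(emb.shape)  # torch.Size([3, 128]); cosine terms in the first half, sine terms in the second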
The TimestepEmbedSequential class
Location: latent-diffusion/modules/diffusionmodules/openaimodel.py
TimestepEmbedSequential inherits from torch.nn.Sequential. It makes it easy to pass the timestep embedding and the conditioning to exactly those layers that accept them. The code is:
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb, context=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context)
else:
x = layer(x)
return x
Here TimestepBlock is a very simple abstract class (see below), and SpatialTransformer is the class that cross-attends the conditioning with the image features (see below).
Quick links: the UNetModel class | UNetModel downsampling | UNetModel middle block.
The TimestepBlock class
Location: latent-diffusion/modules/diffusionmodules/openaimodel.py
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
(Back to the TimestepEmbedSequential class.)
The SpatialTransformer class
Location: latent-diffusion/modules/attention.py
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c')
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x + x_in
Here BasicTransformerBlock is a fairly standard Transformer block:
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
Here, the code of the CrossAttention class is:
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
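A hedged shape check (repo on PYTHONPATH; sizes are illustrative): the queries come from the image tokens, while the keys and values come from the context. When context_dim is None, the keys and values are projected from x itself, i.e. plain self-attention.
import torch
from ldm.modules.attention import CrossAttention

attn = CrossAttention(query_dim=64, context_dim=77, heads=8, dim_head=8)
x = torch.randn(2, 256, 64)            # image tokens: [b, h*w, c]
context = torch.randn(2, 10, 77)       # conditioning tokens
print(attn(x, context=context).shape)  # torch.Size([2, 256, 64])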
And the code of FeedForward is:
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
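The glu branch uses GEGLU, which is not shown here; conceptually it projects to twice the inner width and gates one half with a GELU of the other. A sketch of the idea (my reconstruction, not necessarily the repo's exact class):
import torch
import torch.nn as nn
import torch.nn.functional as F

class GEGLUSketch(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)   # value and gate, concatenated
    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)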
Back to the TimestepEmbedSequential class.
ResBlock class
Location: latent-diffusion/ldm/modules/diffusionmodules/openaimodel.py
This class implements a residual block conditioned on the timestep embedding. When use_scale_shift_norm is enabled, the embedding is split into a scale and a shift that modulate the normalized features; otherwise it is simply added to them. The code is straightforward:
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
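A hedged usage sketch (repo on PYTHONPATH; sizes are illustrative), showing the channel change and the scale-shift conditioning path:
import torch
from ldm.modules.diffusionmodules.openaimodel import ResBlock

block = ResBlock(channels=64, emb_channels=256, dropout=0.0,
                 out_channels=128, use_scale_shift_norm=True)
x = torch.randn(2, 64, 32, 32)
emb = torch.randn(2, 256)            # timestep embedding
print(block(x, emb).shape)           # torch.Size([2, 128, 32, 32])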
Quick links back: UNetModel downsampling | UNetModel middle layers.
AttentionBlock class
Location: latent-diffusion/ldm/modules/diffusionmodules/openaimodel.py
This class implements a self-attention block over spatial positions. The code is as follows:
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
if use_new_attention_order:
# split qkv before split heads
self.attention = QKVAttention(self.num_heads)
else:
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
#return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
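A hedged usage sketch (repo on PYTHONPATH; sizes are illustrative): unlike SpatialTransformer, this block takes no context; it only lets the spatial positions attend to each other.
import torch
from ldm.modules.diffusionmodules.openaimodel import AttentionBlock

attn = AttentionBlock(channels=64, num_heads=4)
x = torch.randn(2, 64, 16, 16)
print(attn(x).shape)   # torch.Size([2, 64, 16, 16])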
Here, the QKVAttention class (the new attention order, which splits q, k, v before splitting heads) is defined as follows:
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
The QKVAttentionLegacy class (the legacy order, which splits heads before splitting q, k, v) is defined as follows; a shape check comparing the two is given after its code:
class QKVAttentionLegacy(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
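The two classes differ only in how the channel dimension of the packed qkv tensor is interpreted: QKVAttention assumes q, k and v are concatenated first and the heads are split afterwards, while QKVAttentionLegacy treats the heads as the outer split. Either way the input is [N, 3*H*C, T] and the output is [N, H*C, T]; a quick shape check (hypothetical sizes, repo on PYTHONPATH):
import torch
from ldm.modules.diffusionmodules.openaimodel import QKVAttention, QKVAttentionLegacy

heads, ch, length = 4, 16, 10
qkv = torch.randn(2, 3 * heads * ch, length)     # [N, 3*H*C, T]
print(QKVAttention(heads)(qkv).shape)            # torch.Size([2, 64, 10])
print(QKVAttentionLegacy(heads)(qkv).shape)      # torch.Size([2, 64, 10])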
Quick links back: UNetModel downsampling | UNetModel middle layers.
Downsample class
Location: latent-diffusion/ldm/modules/diffusionmodules/openaimodel.py
This class implements the downsampling module used in the UNet: a stride-2 convolution when use_conv is True, otherwise average pooling. The code is as follows:
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
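A hedged usage sketch (repo on PYTHONPATH): with use_conv=True the stride-2 3x3 convolution halves the spatial resolution and out_channels may differ from channels; with use_conv=False it falls back to average pooling and the channel count must stay the same.
import torch
from ldm.modules.diffusionmodules.openaimodel import Downsample

down = Downsample(channels=64, use_conv=True)    # stride-2 3x3 convolution
x = torch.randn(2, 64, 32, 32)
print(down(x).shape)   # torch.Size([2, 64, 16, 16])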