为什么channels是4?
这个first_stage_key:jpg是要生成的风格的图吗?
control_stage_config:
ControlNet的参数配置信息
把control信息映射为feature map,和latent z一样的大小
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key):
self.diffusion_model = instantiate_from_config(diff_model_config)
# diff_model_config = unet_config
self.conditioning_key = conditioning_key
diff_model_config是怎么来的,在哪里的设置?
在ddpm.py中的line92中有self.model = DiffusionWrapper(unet_config, conditioning_key)
ControlLDM的get_input
LatentDiffusion的get_input
DDPM(pyLightning)的get_input
batch_size=4, dataloader实例会把数据集中的数据分批次提供给模型,每次提供batch_size个样本,通过使用DataLoader,可以简化模型训练过程中的数据管理和处理
first_stage_key=jpg这个first_stage就是把原始图片经过VAE的encoder变为latent z
self.model = DiffusionWrapper(unet_config, conditioning_key)
ControlNet
ControlNet其实就是和ldm/modules/diffusionmodules/openaimodel.py的UnetModel的input_block, middle_block是一样的,只是它多加入了zero_conv的操作
control = self.control_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
def forward(self, x, hint, timesteps, context, **kwargs):
'''
x: 加了噪声的latent z
'''
# 把时间t编码为vector
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb) # linear->relu->linear emb.shape=(B, time_embed_dim)
# 只会对 hint 进行操作,因为这里的input_hint_block里面的类型不是TimestepBlock和SpatialTransformer
# 那相当于这里只对 hint 进行conv->silu->conv->silu->...->conv->silu->zero_conv
guided_hint = self.input_hint_block(hint, emb, context)
outs = []
h = x.type(self.dtype) # 将 x 的数据类型转换为 self.dtype。
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:
h = module(h, emb, context) # conv_nd(h)
h += guided_hint # x + conv_nd(hint)
guided_hint = None
else:
h = module(h, emb, context)
outs.append(zero_conv(h, emb, context)) # 对加入hint后的h再次conv_nd
h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context)) # make_zero_conv
return outs # ControlNet的输出,即zero_conv->ldm->zero_conv->outs
- Unet的input_block:
横着进行 ResBlock,AttentionBlock -> zero_conv
竖着进行ResBlock, Downsample->zero_conv
-
Unet的middle_block:
ResBlock, AttentionBlock, ResBlock
self.middle_block_out = self.make_zero_conv(ch)
hint 经过self.input_hint_block(conv->silu->conv->silu->…->conv->silu->zero_conv), 得到guided_hint
即从[b,c,h,w]变成了feature map, 又经过zero_conv
x先经过neural network block, 然后和 guided_hint相加,然后将guided_hint清零
再经过zero_conv,并将此control block的东西append到outs中
controlNet的输出是每个block经过zero_conv后的输出,包括input_block的各个block的zero_conv输出,和middle_block经过zero_conv的输出。
controlNet返回的outs应该是一个列表,包含input_block和middle_block各个block的zero_conv的输出
总结:
ControlNet就是为了获得input_block,middle_block的各个block 的zero_conv的输出!
diffusion_model = self.model.diffusion_model=instantiate_from_config(unet_config)
而unet_config = cldm.cldm.ControlledUnetModel
ControlledUnetModel
controlledUnetModel继承自UNetModel
class ControlledUnetModel(UNetModel):
def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
hs = [] # hs 用于保存降采样的每一层的输出,这样在上采样的时候方便进行skip-connection
with torch.no_grad(): # Fig3. Locked copy : no grad
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb) # 经过linear
h = x.type(self.dtype) # self.dtype表示所需的目标数据类型 ; type()是PyTorch张量的一个方法,用于将张量转换为指定的数据类型。传递给type()的参数是所需的目标数据类型。
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
# 上面的是从 Locked Encoder(DownSample) + Middle 出来的 y
if control is not None:
h += control.pop() # y + ControlNet's ouptut
for i, module in enumerate(self.output_blocks): # Decoder(UpSample)
if only_mid_control or control is None: #
h = torch.cat([h, hs.pop()], dim=1)
else:
h = torch.cat([h, hs.pop() + control.pop()], dim=1) # 这里要弄清楚control里面是什么,为什么还可以 pop
h = module(h, emb, context)
h = h.type(x.dtype)
return self.out(h)
对于context,也就是conditon
弄清楚,这个condition是怎么加到z_noise中的呢?
是cat还是直接+呢?
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
# hs:方便以后的skip connection
h = self.middle_block(h, emb, context)
对于control.pop, 也就是ControlNet的输出,
弄清楚,这个ControlNet的输出是如何加到z_noise中的呢?
在cldm.py中line 60:
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
h = module(h, emb, context)
这个controlNet的输出是先和skip_connection的部分直接相加,然后和h进行cat,通道上的拼接,之后再进行Up sample操作,对于每个block皆是如此
最后到达unet的右边上端后,经过self.out(h),输出预测的epsilon噪声。
总结:
ControlledUnetModel如果没有control信息,就相当于正常的LDM流程;如果有control信息就把control信息经过controlNet的架构一个个地加入到upsample的过程中去;也就是论文中的Figure 3.的整体架构
DDPM
DDPM和UNetModel 的关系
class DDPM(py.LightningModule):
self.model = DiffusionWrapper(unet_config, conditioning_key)
def q_mean_variance(self, x_start, t):
def predict_start_from_noise(self, x_t, t, noise)
def predict_eps_from_z_and_v(self, x_t, t, v)
class DiffusionWrapper(pl.LightningModule):
# unet_model的包裹体,即将条件也加入到扩散模型中去
self.diffusion_model = instantiate_from_config(diff_model_config)
# diff_model_config = unet_config = ControlledUnetModel(UNetModel)
DDPM和UNetModel 的关系:
-
DDPM的self.model是DiffusionWrapper,
-
DiffusionWrapper的self.diffusion_model是ControlledUnetModel,
-
ControlUnetModel又继承自UNetModel
ControlLDM
ControlLDM( LatentDiffusion( DDPM( pl.LightningModule)))
ControlLDM将整个流程串起来
get_input(self, batch, k , bs=None):
接收数据batch,获得latent z,和编码为feature map的condition;但control还是原始数据形式,[b, c, h, w]
{x, condition, control}
return x, dict(c_crossattn=[c], c_concat=[control])
该字典包含条件信息 c(用于交叉注意力)和控制信息 control(用于连接操作)。
**apply_model(self, x_noisy, t, cond): **
x_noisy: 加噪后的latent z
t: 时间信息
cond: 是一个字典{“c_crossattn”:[c], “c_concat”: [control]}包含condition和control两个信息
diffusion_model就是UNetModel
如果没有control信息:
# 预测noise
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
如果有control信息:
# 先获取control列表,即获取input_blocks和middle_block的每个zero_conv的输出并保存在control中
control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
# 预测noise
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
总结:
ControlLDM就是项目总负责人,将整个流程串起来