主要包括三个模块,Controlnet的采样方法配置,zero123的Unet架构配置,和Hifa的优化框架与参数配置,此篇讲解Controlnet的采样方法配置。
DDIMSampler类
初始化方法:
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
model
: 用于生成图像的扩散模型。schedule
: 采样的调度策略,默认是线性的。ddpm_num_timesteps
: 从模型中获取扩散步骤的数量。
p_sample_ddim
方法:
这是DDIM采样的主要方法,使用Denoising Diffusion Implicit Models进行确定性采样,将随机扩散过程转换为确定的路径。
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
x
: 当前扩散步骤的图像。c
: 条件(比如文本或其他图像特征)。t
: 当前时间步。index
: 当前时间步的索引。repeat_noise
,use_original_steps
,quantize_denoised
,temperature
,noise_dropout
,score_corrector
,corrector_kwargs
,unconditional_guidance_scale
,unconditional_conditioning
,dynamic_threshold
: 采样过程中的各种参数。
主要步骤包括:
- 模型输出计算:
根据是否有无条件引导来计算模型输出。
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
model_t = self.model.apply_model(x, t, c)
model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
- 预测噪声:
根据模型的参数化方法预测噪声。
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
3.得分修正器:
如果有得分修正器,对噪声进行修正。
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
- 计算采样参数:
根据是否使用原始步骤计算采样参数。
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
- 预测图像:
计算前一时间步的图像。
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
- 应用量化和动态阈值(如果有):
if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) if dynamic_threshold is not None: raise NotImplementedError()
- 计算最终图像:
结合噪声和预测图像,得到前一时间步的图像。
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
2. to_rgb_image
函数
这个函数用于将输入图像转换为RGB格式。
def to_rgb_image(maybe_rgba: Image.Image):
if maybe_rgba.mode == 'RGB':
return maybe_rgba
elif maybe_rgba.mode == 'RGBA':
rgba = maybe_rgba
img = numpy.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
img = Image.fromarray(img, 'RGB')
img.paste(rgba, mask=rgba.getchannel('A'))
return img
else:
raise ValueError("Unsupported image type.", maybe_rgba.mode)
maybe_rgba
: 输入的PIL图像。- 如果图像模式是RGB,直接返回。
- 如果图像模式是RGBA,将其转换为RGB:
- 创建一个随机噪声的RGB图像。
- 将RGBA图像粘贴到RGB图像上,使用其Alpha通道作为掩码。
- 如果图像模式不是RGB或RGBA,抛出错误。
总结
这段代码实现了一个用于图像生成的DDIM采样器类,并提供了一个函数将输入图像转换为RGB格式。这两个功能在图像生成和处理过程中是非常重要的。DDIM采样器可以通过确定性路径生成高质量的图像,而图像格式转换确保输入图像符合模型的要求。