基于DDIM的图像采样与格式转换:实现与应用详解

主要包括三个模块,Controlnet的采样方法配置,zero123的Unet架构配置,和Hifa的优化框架与参数配置,此篇讲解Controlnet的采样方法配置。

DDIMSampler类

初始化方法:

def __init__(self, model, schedule="linear", **kwargs):
    super().__init__()
    self.model = model
    self.ddpm_num_timesteps = model.num_timesteps
    self.schedule = schedule
  • model: 用于生成图像的扩散模型。
  • schedule: 采样的调度策略,默认是线性的。
  • ddpm_num_timesteps: 从模型中获取扩散步骤的数量。

p_sample_ddim方法:

这是DDIM采样的主要方法,使用Denoising Diffusion Implicit Models进行确定性采样,将随机扩散过程转换为确定的路径。

@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None,
                  dynamic_threshold=None):
  • x: 当前扩散步骤的图像。
  • c: 条件(比如文本或其他图像特征)。
  • t: 当前时间步。
  • index: 当前时间步的索引。
  • repeat_noise, use_original_steps, quantize_denoised, temperature, noise_dropout, score_corrector, corrector_kwargs, unconditional_guidance_scale, unconditional_conditioning, dynamic_threshold: 采样过程中的各种参数。

主要步骤包括:

  1. 模型输出计算:

根据是否有无条件引导来计算模型输出。

if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
    model_output = self.model.apply_model(x, t, c)
else:
    model_t = self.model.apply_model(x, t, c)
    model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
    model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
  1. 预测噪声:

根据模型的参数化方法预测噪声。

if self.model.parameterization == "v":
    e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
    e_t = model_output

3.得分修正器:

如果有得分修正器,对噪声进行修正。

if score_corrector is not None:
    assert self.model.parameterization == "eps", 'not implemented'
    e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
  1. 计算采样参数:

根据是否使用原始步骤计算采样参数。

alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
  1. 预测图像:

计算前一时间步的图像。

if self.model.parameterization != "v":
    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
    pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
  1. 应用量化和动态阈值(如果有):
    if quantize_denoised:
        pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
    
    if dynamic_threshold is not None:
        raise NotImplementedError()
    

  2. 计算最终图像:

结合噪声和预测图像,得到前一时间步的图像。

dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
    noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0

2. to_rgb_image函数

这个函数用于将输入图像转换为RGB格式。

def to_rgb_image(maybe_rgba: Image.Image):
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    elif maybe_rgba.mode == 'RGBA':
        rgba = maybe_rgba
        img = numpy.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
        img = Image.fromarray(img, 'RGB')
        img.paste(rgba, mask=rgba.getchannel('A'))
        return img
    else:
        raise ValueError("Unsupported image type.", maybe_rgba.mode)
  • maybe_rgba: 输入的PIL图像。
  • 如果图像模式是RGB,直接返回。
  • 如果图像模式是RGBA,将其转换为RGB:
    • 创建一个随机噪声的RGB图像。
    • 将RGBA图像粘贴到RGB图像上,使用其Alpha通道作为掩码。
  • 如果图像模式不是RGB或RGBA,抛出错误。

总结

这段代码实现了一个用于图像生成的DDIM采样器类,并提供了一个函数将输入图像转换为RGB格式。这两个功能在图像生成和处理过程中是非常重要的。DDIM采样器可以通过确定性路径生成高质量的图像,而图像格式转换确保输入图像符合模型的要求。

  • 6
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值