Diffusion核心原理以及核心代码部分详解

Diffusion 原理及attention代码详解

reference:
https://zhuanlan.zhihu.com/p/642354007
https://zhuanlan.zhihu.com/p/677234407
https://zhuanlan.zhihu.com/p/632809634
https://proceedings.neurips.cc/paper/2020/hash/4c5bcfec8584af0d967f1ab10179ca4b-Abstract.html

0. 原理部分(DDPM )为例

0.1 扩散过程(加噪过程)

请添加图片描述

  • 前向过程的每一步采样可以看作当前图像和高斯噪声的加权组合。加到最后完全变为高斯噪声。
  • T = 1000,扩散1000次,采样1000次。
  • β \beta β 取值 [ 1 0 − 4 , 0.02 ] [10^{-4},0.02] [104,0.02]。也就是说随着时间步的增大,高斯权重的噪声越来越大。

0.2 训练过程

请添加图片描述
根据
x t = α t x t − 1 + 1 − α t ϵ x_t = \sqrt{\alpha_t}x_{t-1} + \sqrt{1 - \alpha_t}\epsilon xt=αt xt1+1αt ϵ
可以推出 x t x_t xt 根据 x 0 x_0 x0 的高斯采样分布为:
q ( x t ∣ x 0 ) = N ( x t ; α ‾ t x 0 , 1 − α ‾ t I ) q(x_t|x_0) = \mathcal{N}(x_t;{\overline{\alpha}_t}x_0 ,1 - \overline{\alpha}_t I) q(xtx0)=N(xt;αtx0,1αtI)
重采样形式为:
x t = α ‾ t x 0 + 1 − α ‾ t ϵ t x_t = \sqrt{\overline{\alpha}_t}x_0 + \sqrt{1 - \overline{\alpha}_t }\epsilon_t xt=αt x0+1αt ϵt
因此在训练过程中可以得到任意一个 x t x_t xt 来计算l2误差损失。

0.3 采样过程

请添加图片描述
在这里插入图片描述

  • 在扩散过程中, q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0)可以根据高斯分布推到出来,也就是知道 x 0 x_0 x0 x t x_t xt 就可以推导出来 x t − 1 x_{t-1} xt1 的分布。
  • 但是在采样过程中,由于不知道 x 0 x_0 x0 因此 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt1xt)是未知的。需要预测噪声进而预测出一个 x 0 x_0 x0带入分布。

0.4 为什么优化目标是误差最小?

一开始优化的目标是负log似然:
− l o g P θ ( x 0 ) -logP_\theta(x_0) logPθ(x0)

然而又因为
请添加图片描述
推理过程:
请添加图片描述

因此只需要优化 L t − 1 L_{t-1} Lt1就好, L T L_T LT L 0 L_0 L0是定值。
让两个高斯分布的KL散度最小,就是让他们的均值和方差最小。而方差 σ t \sigma_t σt 在DDPM中假设为定值 β t \beta_t βt所以不优化。
均值部分为
在这里插入图片描述

推导可以有:
请添加图片描述

1. 核心算法

1.1 采样部分

def ldm_text_to_image(image_shape, text, ddim_steps = 20, eta = 0):
  ddim_scheduler = DDIMScheduler()
  vae = VAE()
  unet = UNet()
  zt = randn(image_shape)
  T = 1000
  timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]

  text_encoder = CLIP()
  c = text_encoder.encode(text)

  for t = timesteps:
    eps = unet(zt, t, c)
    std = ddim_scheduler.get_std(t, eta)
    zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
  xt = vae.decoder.decode(zt)
  return xt
  • ddim_scheduler: 包含了 α , β \alpha,\beta α,β的参数,时间步 t e m p s t e p tempstep tempstep,方差 s t d std std等信息。
  • VAE 是stable diffusion 是将图像转化为隐空间的编码器。
  • unet 负责根据时间步骤 t t t 以及第 t t t时刻的图像状态以及文本约束 c c c 来预测噪声。

1.2 Unet 结构

1.2.1 总体结构示意图:

请添加图片描述

1.2.2 模块细节

请添加图片描述

1.2.3 改进的核心模块

请添加图片描述

  • time embedding 从ResNetBlock 中加入。相加时在H,W维度上进行了广播。
  • context 文本约束从Spatial Transformer 中加入。

2. Attention 部分的修改

2.1 attn_processors字典的修改。

在Unet 中,每一个attention模块都对应一个AttentionProcessor类和实例,通过 Unset 中的 attn_processors 字典维护。
它的key 是attention 的位置或者说是网络模块的名字,在修改时需要修改attn_processors中我们想要修改部分的AttentionProcessor
例如,

attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
            attn_procs[name].load_state_dict(weights)
    unet.set_attn_processor(attn_procs)

2.2 自定义的AttentionProcessor

下面是原始的AttentionProcessor,可以修改其中的逻辑,或者添加新的可以学习的变量。
例如IP-Adapter 中就加入了新的image 的cross-attention。
但是注意,_call_()中的变量不能更改。想要传入新的变量,可以cancate 到encoder_hidden_states 上。

class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
  • attn:已经定义好的attention,不可以修改,可以用,也可不用。
  • hidden_states:Transformer Block 的中间变量。
  • encoder_hidden_states:文本编码器,一般是修改这个结构,例如IP-adapter 中在这里拼接了image_embedding。
  • attention_mask 、temb:一般不用。
  • 19
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
stable diffusion 是一种图像处理技术,通过应用不同的滤波器和参数调整,可以达到稳定图像的效果。Dreambooth 是一个用于定制自己的 stable diffusion 的工具。 Dreambooth 的原理是基于稳定扩散的原始算法,通过反复迭代将图像平滑处理,达到消除噪音和增加细节的目的。该算法的主要思想是在滤波器的各个位置上应用聚合函数,以合并邻域内的像素值。图像的每个像素点在该过程中被赋予一个新的值,以确保图像的平滑和细节。 在使用 Dreambooth 进行实战时,首先需要选择一个适合的滤波器类型和参数。常用的滤波器类型包括均值滤波器、中值滤波器等。选择不同的滤波器类型和参数可以得到不同的效果。接下来,将选择的滤波器和参数应用于输入图像,可以使用编程语言如Python来实现相关代码。 以下是一个简单的示例代码,展示了如何使用 Python 和 OpenCV 库来实现 Dreambooth 的效果: ```python import cv2 def dreambooth(image, filter_type, filter_size): blurred_image = cv2.blur(image, (filter_size, filter_size)) # 使用均值滤波器进行图像模糊 detail_image = image - blurred_image # 计算细节图像 result_image = image + detail_image # 合并细节和原始图像 return result_image # 读取输入图像 image = cv2.imread('input_image.jpg') # 设置滤波器类型和大小 filter_type = cv2.MEAN # 均值滤波器 filter_size = 5 # 滤波器大小 # 应用 Dreambooth result = dreambooth(image, filter_type, filter_size) # 显示结果图像 cv2.imshow('Result', result) cv2.waitKey(0) cv2.destroyAllWindows() ``` 通过调整滤波器类型和大小,可以实现不同的图像处理效果。在使用 Dreambooth 进行定制时,可以根据自己的需求和实际情况选择适合的滤波器和参数,以达到最佳的稳定扩散效果。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值