ControlNet: Principles and Code Walkthrough

ControlNet: TL;DR

ControlNet was introduced in Adding Conditional Control to Text-to-Image Diffusion Models by Lvmin Zhang and Maneesh Agrawala. It provides a framework that allows a variety of spatial contexts to serve as additional conditionings for diffusion models such as Stable Diffusion.

Training a ControlNet consists of the following steps (each new type of condition requires training a new copy of the ControlNet weights):

  1. Copy the pre-trained parameters of a diffusion model, e.g. Stable Diffusion's latent UNet (the "trainable copy"), while separately keeping the original pre-trained parameters (the "locked copy"). The locked copy preserves the rich prior knowledge learned from large datasets, while the trainable copy learns the task-specific aspects.

  2. The trainable and locked copies are connected through "zero convolution" layers, which are optimized as part of the ControlNet framework. Their parameters grow gradually from zero, guaranteeing that no random noise disturbs the fine-tuning at the beginning. This is a training trick that preserves the semantics the frozen model has already learned while a new condition is being trained (a minimal sketch follows below).
    At inference time, both the pre-trained SD diffusion weights and the fine-tuned ControlNet weights are required. Using Stable Diffusion v1-5 together with a ControlNet checkpoint needs roughly 700 million more parameters than the original Stable Diffusion model alone, so ControlNet is more memory-hungry at inference. When switching to a different condition, only the ControlNet weights need to be swapped, which makes deploying multiple ControlNet weights in a single application quite simple.
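
A minimal sketch of the zero-convolution idea from step 2 (my own illustration; the corresponding diffusers helper, zero_module, appears later in this post):

import torch
import torch.nn as nn

def zero_conv(channels: int) -> nn.Conv2d:
    # 1x1 convolution whose weights and bias start at exactly zero
    conv = nn.Conv2d(channels, channels, kernel_size=1)
    nn.init.zeros_(conv.weight)
    nn.init.zeros_(conv.bias)
    return conv

# At the very start of training the trainable copy therefore adds nothing
# to the locked copy's features, so the frozen semantics are untouched:
x = torch.randn(1, 320, 64, 64)
assert torch.all(zero_conv(320)(x) == 0)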

Control Type

About 15 types, including Canny, Depth, Openpose, Normal, Seg, Scribble, MLSD, and HED:

  • First, the most important one: lineart. The lineart preprocessor is the model used to extract a lineart image (it comes in general and anime variants); the lineart model is the weight of the lineart ControlNet (also general vs. anime). Note that the combination of an anime preprocessor (whose output contains small grid artifacts) with the general lineart model (a general ControlNet does not handle those grids well) gives poor results. A usage sketch with controlnet_aux follows after this list.
    The general lineart preprocessors lineart_realistic and lineart_coarse can extract lineart from all kinds of images, e.g. 2D animation, 3D photorealistic, and 3D animation (i.e. 2.5D).
    Then come the anime lineart preprocessors lineart_anime and lineart_anime_denoise, where "denoise" means the extracted lineart is denoised, leaving the prompt more room to play.

  • Canny edges: controlnet='lllyasviel/sd-controlnet-canny', a monochrome image with white edges on a black background.

  • Depth map: lllyasviel/sd-controlnet-depth, a grayscale image where black represents deep regions and white represents shallow regions.

  • Pose skeleton (OpenPose): controlnet='fusing/stable-diffusion-v1-5-controlnet-openpose' or lllyasviel/sd-controlnet-openpose.

  • Soft edges (HED): lllyasviel/sd-controlnet-hed, a monochrome image with soft white edges on a black background.


  • Normal map: lllyasviel/sd-controlnet-normal
  • M-LSD lines: lllyasviel/sd-controlnet-mlsd, a monochrome image composed only of white straight lines on a black background.
  • Semantic segmentation (seg): lllyasviel/sd-controlnet-seg
  • Human scribbles: lllyasviel/sd-controlnet-scribble

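The lineart preprocessors mentioned in the first item can be called through controlnet_aux; a minimal sketch, assuming a recent controlnet_aux version that exposes LineartDetector / LineartAnimeDetector, the lllyasviel/Annotators checkpoint, and a coarse flag on LineartDetector:

from controlnet_aux import LineartDetector, LineartAnimeDetector
from diffusers.utils import load_image

img = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")

lineart = LineartDetector.from_pretrained("lllyasviel/Annotators")
lineart_anime = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")

realistic_lines = lineart(img)            # lineart_realistic-style extraction
coarse_lines = lineart(img, coarse=True)  # lineart_coarse-style extraction
anime_lines = lineart_anime(img)          # anime lineart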

StableDiffusionControlNetPipeline

Like other diffusers pipelines, StableDiffusionControlNetPipeline can load pre-trained weights from the Hugging Face Hub.

Install the required libraries from the command line:

# diffusers dependencies
pip install -q diffusers==0.14.0 transformers xformers git+https://github.com/huggingface/accelerate.git
# dependencies for the various condition preprocessors
pip install -q opencv-contrib-python
pip install -q controlnet_aux

1. Canny ControlNet

1.1 Loading the model and data

Load the image:

from diffusers import StableDiffusionControlNetPipeline
from diffusers.utils import load_image

image = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
image

Use OpenCV's Canny algorithm to extract the edge image as the condition:

import cv2
from PIL import Image
import numpy as np

image = np.array(image)

low_threshold = 100
high_threshold = 200

image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
canny_image


Load SD (runwayml/stable-diffusion-v1-5) and the Canny ControlNet (lllyasviel/sd-controlnet-canny) in half precision (torch_dtype=torch.float16) for faster inference; here controlnet='lllyasviel/sd-controlnet-canny'.

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import torch

# load Stable Diffusion and the ControlNet separately, then compose them into a StableDiffusionControlNetPipeline
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)

Use UniPCMultistepScheduler, currently one of the fastest diffusion noise schedulers, to reduce the number of inference steps from 50 to 20:

from diffusers import UniPCMultistepScheduler

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

Use the pipeline's enable_model_cpu_offload method for automatic GPU memory management, so there is no need to call to("cuda") manually. During inference, a pipeline such as SD runs several model components in sequence: with ControlNet and Stable Diffusion we first run the CLIP text encoder, then the UNet and the ControlNet, then the VAE decoder, and finally the safety checker. Most components run only once per generation and do not need to occupy GPU memory the whole time. Enabling enable_model_cpu_offload ensures each component is loaded onto the GPU only when it is needed, saving a significant amount of memory.

pipe.enable_model_cpu_offload()

Use xformers (FlashAttention-style) memory-efficient attention for extra speed (skip this if xformers is not installed):

pipe.enable_xformers_memory_efficient_attention()
1.2 Inference

Compare generations with and without the quality suffix in the prompt:

  • The text prompt should describe the desired image as clearly, specifically, and concisely as possible, avoiding vague, verbose, or contradictory wording. A prompt can be decomposed into subject, environment/mood, art category, art style, material, composition, viewpoint, lighting, color tone, and so on; emojis, character names, and scene names can also add expressiveness.
  • Prompt suffix keywords can affect the quality, resolution, etc. of the generated image. Separate different keywords with commas.
generator = [torch.Generator(device="cpu").manual_seed(2)]  # fixed seed so the two runs are comparable

normal_image = pipe(
    ["Trump"],
    canny_image,
    negative_prompt=["monochrome, lowres, bad anatomy, worst quality, low quality"],
    generator=generator[0],
    num_inference_steps=20,
)


# with the positive-prompt quality suffix
good_image = pipe(
    ["Trump, best quality, extremely detailed"],
    canny_image,
    negative_prompt=["monochrome, lowres, bad anatomy, worst quality, low quality"],
    generator=generator[0],
    num_inference_steps=20,
)


1.3 DreamBooth fine-tuning

ControlNet can also be combined with a DreamBooth fine-tuned model: load a StableDiffusionPipeline (Stable Diffusion fine-tuned with DreamBooth on the Mr Potato Head subject) together with exactly the same Canny ControlNet as before to build a StableDiffusionControlNetPipeline.
Mr Potato Head

model_id = "sd-dreambooth-library/mr-potato-head"
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    model_id,  # model_id of the DreamBooth-fine-tuned StableDiffusionPipeline
    controlnet=controlnet,  # the ControlNetModel instance
    torch_dtype=torch.float16,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.enable_xformers_memory_efficient_attention()  # comment this out if xformers is not installed

Inference:

generator = torch.manual_seed(2)
prompt = "a photo of sks mr potato head, best quality, extremely detailed"
output = pipe(
    prompt,
    canny_image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    generator=generator,
    num_inference_steps=20,
)


2. Pose ControlNet

2.1 Loading the data and model

Load the yoga images:

urls = "yoga1.jpeg", "yoga2.jpeg", "yoga3.jpeg", "yoga4.jpeg"
imgs = [
    load_image("https://hf.co/datasets/YiYiXu/controlnet-testing/resolve/main/" + url)
    for url in urls
]

image_grid(imgs, 2, 2)
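
image_grid is a small helper (used in the original Hugging Face ControlNet blog post) rather than a diffusers API; a minimal sketch assuming equally sized PIL images:

from PIL import Image

def image_grid(imgs, rows, cols):
    # paste rows * cols equally sized images into a single grid image
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid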


Use OpenposeDetector from controlnet_aux to extract the poses from the images:

from controlnet_aux import OpenposeDetector

model = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

poses = [model(img) for img in imgs]
image_grid(poses, 2, 2)

Load the StableDiffusionControlNetPipeline with controlnet='fusing/stable-diffusion-v1-5-controlnet-openpose':

controlnet = ControlNetModel.from_pretrained(
    "fusing/stable-diffusion-v1-5-controlnet-openpose", torch_dtype=torch.float16
)

model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.enable_xformers_memory_efficient_attention()  # comment this out if xformers is not installed
2.2 Inference
generator = [torch.Generator(device="cpu").manual_seed(2) for i in range(4)]
prompt = "super-hero character, best quality, extremely detailed"
output = pipe(
    [prompt] * 4,
    poses,
    negative_prompt=["monochrome, lowres, bad anatomy, worst quality, low quality"] * 4,
    generator=generator,
    num_inference_steps=20,
)
image_grid(output.images, 2, 2)


ControlNet Source Code Walkthrough


The complete StableDiffusionControlNetPipeline consists of:

        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection = None,

The full flow of these components lives in StableDiffusionControlNetPipeline's __call__ method. The two most critical components are UNet2DConditionModel, an adaptation of the 2D UNet used in SD, and ControlNetModel, which extracts features from the condition image (condition_images) and injects them into that UNet.

The __call__ method of the StableDiffusionControlNetPipeline class

Guided by prompt, image, and ip_adapter_image, it generates a new image.

Inputs: prompt, ip_adapter_image (the reference image), image (the condition image); output: output_image.

Encoding the prompt

The key design of IP-Adapter is a decoupled cross-attention mechanism with separate cross-attention layers for text features and image features; this is what implements the image prompt.
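
In the IP-Adapter formulation this is usually written as adding a separate image cross-attention term on top of the text cross-attention (sketch; $\lambda$ is the image-prompt scale, $K_t, V_t$ are projected from the text embeddings, and $K_i, V_i$ from the image embeddings):

$Z^\text{new} = \text{Attention}(Q, K_t, V_t) + \lambda \cdot \text{Attention}(Q, K_i, V_i)$
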
Following the classifier-free guidance strategy, both text_embedding and image_embedding are constructed; they later serve as the K and V of the UNet's cross-attention:

  • For the text prompt (prompt), the CLIP text encoder encodes it, and the result is concatenated with negative_prompt_embeds to give prompt_embeds.
  • If there is an image prompt (ip_adapter_image), the CLIP image encoder encodes the reference image, and the result is concatenated with negative_image_embeds to give image_embeds.
		# 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None:
            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
            if self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])
Preparing the condition image, timesteps, and latents

Prepare image: the condition image is converted into a tensor (resized to the target height/width, batched, and cast to the ControlNet dtype) before being fed to the ControlNet.
Prepare timesteps: the denoising timesteps are set up.
Prepare latent variables: if precomputed latents are passed in they are used directly; otherwise a random latents tensor is initialized with shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor).
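
The body of prepare_latents is roughly the following (simplified sketch of the diffusers implementation; in recent versions randn_tensor lives in diffusers.utils.torch_utils):

from diffusers.utils.torch_utils import randn_tensor

def prepare_latents_sketch(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
    if latents is None:
        # fresh Gaussian noise in latent space
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)
    # scale the initial noise by the scheduler's expected standard deviation
    return latents * self.scheduler.init_noise_sigma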

         # 4. Prepare image
        if isinstance(controlnet, ControlNetModel):
            image = self.prepare_image(
                image=image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            height, width = image.shape[-2:]
        elif isinstance(controlnet, MultiControlNetModel):
            images = []
            for image_ in image:
                image_ = self.prepare_image(
                    image=image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )

                images.append(image_)

            image = images
            height, width = image[0].shape[-2:]
        else:
            assert False

        # 5. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
The denoising loop (combining the ControlNet outputs with the UNet)
  • For the ControlNetModel (in guess mode), CFG is not applied: the input latents, control_model_input, contain only one copy of the latents, and controlnet_prompt_embeds likewise drops the negative prompt (without guess mode, the ControlNet receives the same CFG-doubled batch as the UNet).
    • The ControlNet forward pass only needs the latents (control_model_input), the text embedding (controlnet_prompt_embeds), and the condition image (image); it does not need the reference image (ip_adapter_image).
    • It outputs down_block_res_samples and mid_block_res_sample, feature maps at the UNet's different resolutions, which are injected into the main UNet.

Because the guess-mode ControlNetModel output is not CFG-batched, its shape is half of what the UNet2DConditionModel expects; before injection, an all-zero feature of the same shape is therefore concatenated onto each of down_block_res_samples and mid_block_res_sample, leaving the unconditional batch unchanged.

  • For the UNet2DConditionModel, CFG is applied: the input latents, latent_model_input, contain two copies of the latents.
    • Its forward pass takes the latents (latent_model_input), the text embedding (prompt_embeds), the image embedding (added_cond_kwargs, i.e. ip_adapter_image), and the ControlNet outputs (down_block_res_samples and mid_block_res_sample).
    • It outputs noise_pred; CFG is then applied, and the scheduler performs the denoising step $x_t \to x_{t-1}$.

Finally, the VAE decodes the latents into the output image.
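
The guidance step in the loop below is the standard classifier-free guidance formula, with guidance scale $s$:

$\hat{\epsilon} = \epsilon_\text{uncond} + s \cdot (\epsilon_\text{text} - \epsilon_\text{uncond})$

which corresponds to noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) in the code.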

		# 8. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # controlnet(s) inference
                if guess_mode and self.do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds
                    
                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=image,
                    conditioning_scale=cond_scale,  # derived from controlnet_conditioning_scale earlier in __call__ (omitted in this excerpt)
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                if guess_mode and self.do_classifier_free_guidance:
                    # Infered ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    added_cond_kwargs=added_cond_kwargs,  # image embedding
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

		if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]

The ControlNetModel class

  • Inputs: latents (sample), text embedding (encoder_hidden_states), and the condition image (controlnet_cond).
	def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
  • Build the time embedding (time_embed + class_embed + text_embed): the timestep is first converted into a time_embedding. If there is a class embedding, the class label is converted into a class_embedding and added to the time_embedding. If the config specifies additional embedding types (such as a text embedding), those are computed as aug_emb and added as well. The result is a single combined embedding, which is returned.
 		# 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")
            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)
            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        if self.config.addition_embed_type is not None:
            if self.config.addition_embed_type == "text":
                aug_emb = self.add_embedding(encoder_hidden_states)

            elif self.config.addition_embed_type == "text_time":
                if "text_embeds" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                    )
                text_embeds = added_cond_kwargs.get("text_embeds")
                if "time_ids" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                    )
                time_ids = added_cond_kwargs.get("time_ids")
                time_embeds = self.add_time_proj(time_ids.flatten())
                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
                add_embeds = add_embeds.to(emb.dtype)
                aug_emb = self.add_embedding(add_embeds)

        emb = emb + aug_emb if aug_emb is not None else emb
  • Preprocess the condition image: ControlNetConditioningEmbedding encodes the condition image, and the result is added to sample. Because Stable Diffusion uses a VQ-GAN-like preprocessing that turns the 512×512 images of the whole dataset into smaller 64×64 latents for stable training, ControlNet must convert the image-based condition into the same 64×64 feature space to match the convolution sizes. The paper describes this encoder as four convolution layers with 4×4 kernels and 2×2 strides (ReLU activations, 16, 32, 64, 128 channels, Gaussian-initialized weights, trained jointly with the full model); its last layer is a zero conv, self.conv_out = zero_module(nn.Conv2d(256, 320, kernel_size=3, padding=1)). A sketch of such an encoder follows the code below.
		# 2. pre-process
        sample = self.conv_in(sample)

        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
        sample = sample + controlnet_cond
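
As referenced above, here is a simplified sketch of such a conditioning encoder: three stride-2 convolutions bring the 512×512 condition image down to the 64×64 latent resolution, and a zero convolution produces the 320-channel output (my own simplification; the actual diffusers ControlNetConditioningEmbedding additionally has a stride-1 convolution per resolution level and configurable channel counts):

import torch
import torch.nn as nn
import torch.nn.functional as F

class CondEmbeddingSketch(nn.Module):
    """Encode a 512x512x3 condition image into the 64x64x320 latent feature space."""

    def __init__(self, channels=(16, 32, 96, 256), out_channels=320):
        super().__init__()
        self.conv_in = nn.Conv2d(3, channels[0], kernel_size=3, padding=1)
        self.blocks = nn.ModuleList(
            [nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)  # each halves the resolution
             for c_in, c_out in zip(channels[:-1], channels[1:])]
        )
        # final layer is a zero convolution, so the embedding is all zeros at initialization
        self.conv_out = nn.Conv2d(channels[-1], out_channels, kernel_size=3, padding=1)
        nn.init.zeros_(self.conv_out.weight)
        nn.init.zeros_(self.conv_out.bias)

    def forward(self, cond):
        x = F.silu(self.conv_in(cond))
        for block in self.blocks:
            x = F.silu(block(x))
        return self.conv_out(x)

# CondEmbeddingSketch()(torch.randn(1, 3, 512, 512)).shape  ->  torch.Size([1, 320, 64, 64])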

zero_module sets the convolution weights (and bias) to zero; these are the zero convolutions from the paper.

# down
output_channel = block_out_channels[0]  # 320 for SD v1.5

controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)  # weights (and bias) initialized to zero
self.controlnet_down_blocks.append(controlnet_block)

def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
  • Run the down blocks: these DownBlocks are copied from SD; while running each block, the resulting sample (latents) at every level is appended to down_block_res_samples.
		# 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples
  • Run the mid block: a MidBlock copied from SD; there is no need to save intermediate samples here.
        # 4. mid
        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample = self.mid_block(sample, emb)
  • Pass each level of features through the zero convs in the controlnet blocks: every entry of down_block_res_samples produced by the ControlNetModel down blocks goes through a zero convolution and is collected into controlnet_down_block_res_samples, which then replaces down_block_res_samples. The mid block output goes through its own zero convolution and is stored in mid_block_res_sample.
        # 5. Control net blocks
        controlnet_down_block_res_samples = ()

        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)
  • Multi-condition merging and scaling: for MultiControlNetModel inference with multiple condition images, down_block_res_samples and mid_block_res_sample are the sums of the outputs for all condition images (a sketch follows the code block below).
        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
            scales = scales * conditioning_scale
            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
            ]
            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )
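
As noted in the multi-condition item above, a simplified sketch of how MultiControlNetModel combines several ControlNets (condensed from the pattern in the diffusers implementation; self.nets holds the individual ControlNetModel instances, and controlnet_cond / conditioning_scale are lists of the same length):

# inside MultiControlNetModel.forward (simplified)
down_block_res_samples, mid_block_res_sample = None, None
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
    down_samples, mid_sample = controlnet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        controlnet_cond=image,
        conditioning_scale=scale,
        guess_mode=guess_mode,
        return_dict=False,
    )
    if i == 0:
        down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
    else:
        # element-wise sum of the residuals produced by each ControlNet
        down_block_res_samples = [
            prev + curr for prev, curr in zip(down_block_res_samples, down_samples)
        ]
        mid_block_res_sample = mid_block_res_sample + mid_sample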

The UNet2DConditionModel class

The ControlNet outputs are added to the UNet outputs.
In the CrossAttnDownBlock2D implementation, the ControlNet residuals are handled by summation, not tensor concatenation.

  • Inputs: latents (sample), text embedding (encoder_hidden_states), image embedding (added_cond_kwargs), and the ControlNet output features (down_block_additional_residuals, mid_block_additional_residual).
	def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
  • Build the composite time embedding: time_emb, class_emb, text_emb, and image_emb are combined into a single embedding (added together, or concatenated when class_embeddings_concat is set).
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)
  • Build the K/V (encoder_hidden_states): depending on the config, one of four modes — text_proj, text_image_proj, image_proj, or ip_image_proj — is used to build the K/V from the text prompt (text embedding) and the reference image (image embedding).
		if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kadinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
            encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
  • Preprocess the latents:
 		# 2. pre-process
        sample = self.conv_in(sample)
  • Run the down blocks: the UNet collects its own down_block_res_samples as usual, and afterwards each UNet feature is summed element-wise with the corresponding ControlNetModel feature rather than concatenated.
		# 3. down
        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,  # T2I-Adapter residual kwargs, prepared earlier in the full forward (omitted here)
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)  # lora_scale comes from cross_attention_kwargs
            down_block_res_samples += res_samples


        if is_controlnet:
            new_down_block_res_samples = ()
            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
            down_block_res_samples = new_down_block_res_samples
  • Run the mid block: likewise, the UNet feature and the ControlNetModel feature are summed at the end rather than concatenated.
		# 4. mid
        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = self.mid_block(sample, emb)

            # To support T2I-Adapter-XL (is_adapter is set earlier in the full forward; omitted in this excerpt)
            if (
                is_adapter
                and len(down_intrablock_additional_residuals) > 0
                and sample.shape == down_intrablock_additional_residuals[0].shape
            ):
                sample += down_intrablock_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual
  • Run the up blocks: the standard SD up blocks.
		# 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    scale=lora_scale,
                )
  • Post-process: a final norm, activation, and conv produce the predicted noise.
# 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)