图像重绘输出

class easyControlnet:
    # 定义一个名为 easyControlnet 的类

    def __init__(self):
        # 初始化函数
        pass

    def apply(self, control_net_name, image, positive, negative, strength, start_percent=0, end_percent=1, control_net=None, scale_soft_weights=1, mask=None, easyCache=None, use_cache=True):
        # 定义 apply 方法,处理 ControlNet 应用逻辑,接受多个参数,包括控制网络名称、图像、正向和负向提示、强度等
        if strength == 0:
            return (positive, negative)
            # 如果强度为 0,直接返回正向和负向提示

        if control_net is None:
            control_net = easyCache.load_controlnet(control_net_name, scale_soft_weights, use_cache)
            # 如果控制网络为空,从缓存中加载控制网络

        if mask is not None:
            mask = mask.to(self.device)
            # 如果提供了掩码,将其转移到设备上

        if mask is not None and len(mask.shape) < 3:
            mask = mask.unsqueeze(0)
            # 如果掩码维度小于 3,增加一个维度

        control_hint = image.movedim(-1, 1)
        # 将图像的最后一个维度移动到第二个维度,生成控制提示

        is_cond = True
        if negative is None:
            p = []
            for t in positive:
                n = [t[0], t[1].copy()]
                c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
                # 如果没有负向提示,处理正向提示
                # 复制正向提示,并设置控制网络的提示信息、强度和百分比范围
                if 'control' in t[1]:
                    c_net.set_previous_controlnet(t[1]['control'])
                    # 如果正向提示中有之前的控制网络,将其设置为前一个控制网络
                n[1]['control'] = c_net
                # 将当前控制网络设置为提示中的控制网络
                n[1]['control_apply_to_uncond'] = True
                # 设置控制网络应用到无条件
                if mask is not None:
                    n[1]['mask'] = mask
                    n[1]['set_area_to_bounds'] = False
                    # 如果有掩码,将掩码添加到提示中,并设置不限制区域
                p.append(n)
            positive = p
        else:
            cnets = {}
            out = []
            for conditioning in [positive, negative]:
                c = []
                for t in conditioning:
                    d = t[1].copy()
                    # 如果有负向提示,分别处理正向和负向提示
                    # 复制每个提示的字典
                    prev_cnet = d.get('control', None)
                    if prev_cnet in cnets:
                        c_net = cnets[prev_cnet]
                    else:
                        c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
                        c_net.set_previous_controlnet(prev_cnet)
                        cnets[prev_cnet] = c_net
                        # 如果提示中有之前的控制网络,检查缓存中是否已有该网络,如果有则使用缓存中的网络,否则创建新的控制网络,并设置提示信息、强度和百分比范围
                    d['control'] = c_net
                    d['control_apply_to_uncond'] = False
                    # 设置当前控制网络并应用到无条件
                    if mask is not None:
                        d['mask'] = mask
                        d['set_area_to_bounds'] = False
                        # 如果有掩码,将掩码添加到提示中,并设置不限制区域
                    n = [t[0], d]
                    c.append(n)
                out.append(c)
            positive = out[0]
            negative = out[1]

        return (positive, negative)
        # 返回处理后的正向和负向提示

主要步骤和功能

  1. 定义类和初始化

    • easyControlnet 类主要用于应用 ControlNet 进行图像处理。
    • __init__ 方法为空,表示没有特定的初始化逻辑。
  2. apply 方法

    • apply 方法是这个类的核心功能,接受多个参数来设置和应用 ControlNet。
    • 如果 strength 为 0,直接返回原始的 positivenegative 提示。
  3. 加载控制网络

    • 如果 control_net 为空,从缓存中加载指定的控制网络。
  4. 处理掩码

    • 如果提供了掩码,将其转移到指定的设备上,并确保掩码的维度合适。
  5. 生成控制提示

    • 通过调整图像维度生成控制提示 control_hint
  6. 处理正向提示

    • 如果没有负向提示,逐一处理正向提示,设置控制网络和相关参数。
    • 如果有掩码,将掩码应用到每个提示中。
  7. 处理正向和负向提示

    • 如果有负向提示,分别处理正向和负向提示,设置控制网络和相关参数,并进行缓存管理。
  8. 返回结果

    • 返回处理后的正向和负向提示。

这个类的主要作用是根据给定的图像和提示信息应用 ControlNet,以实现特定的图像处理任务。它可以与大模型(如预训练的 ControlNet 模型)一起使用,以增强和调整图像生成过程。

class IPAdapter(nn.Module):
    # 定义一个名为 IPAdapter 的神经网络模块,继承自 nn.Module

    def __init__(self, ipadapter_model, cross_attention_dim=1024, output_cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4, is_sdxl=False, is_plus=False, is_full=False, is_faceid=False, is_portrait_unnorm=False):
        # 初始化函数,设置默认参数和模型配置
        super().__init__()

        self.clip_embeddings_dim = clip_embeddings_dim
        # 设置 CLIP 嵌入维度
        self.cross_attention_dim = cross_attention_dim
        # 设置交叉注意力维度
        self.output_cross_attention_dim = output_cross_attention_dim
        # 设置输出交叉注意力维度
        self.clip_extra_context_tokens = clip_extra_context_tokens
        # 设置 CLIP 额外上下文标记数量
        self.is_sdxl = is_sdxl
        # 是否使用 SDXL 模型
        self.is_full = is_full
        # 是否使用全功能模型
        self.is_plus = is_plus
        # 是否使用增强模型
        self.is_portrait_unnorm = is_portrait_unnorm
        # 是否使用未归一化的肖像模型

        if is_faceid and not is_portrait_unnorm:
            self.image_proj_model = self.init_proj_faceid()
            # 如果是人脸识别且不是未归一化的肖像模型,初始化人脸识别投影模型
        elif is_full:
            self.image_proj_model = self.init_proj_full()
            # 如果是全功能模型,初始化全功能投影模型
        elif is_plus or is_portrait_unnorm:
            self.image_proj_model = self.init_proj_plus()
            # 如果是增强模型或未归一化的肖像模型,初始化增强投影模型
        else:
            self.image_proj_model = self.init_proj()
            # 否则,初始化基本投影模型

        self.image_proj_model.load_state_dict(ipadapter_model["image_proj"])
        # 加载图像投影模型的状态字典
        self.ip_layers = To_KV(ipadapter_model["ip_adapter"])
        # 初始化 IP 层

    def init_proj(self):
        # 初始化基本图像投影模型
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.cross_attention_dim,
            clip_embeddings_dim=self.clip_embeddings_dim,
            clip_extra_context_tokens=self.clip_extra_context_tokens
        )
        return image_proj_model

    def init_proj_plus(self):
        # 初始化增强型图像投影模型
        image_proj_model = Resampler(
            dim=self.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=20 if self.is_sdxl else 12,
            num_queries=self.clip_extra_context_tokens,
            embedding_dim=self.clip_embeddings_dim,
            output_dim=self.output_cross_attention_dim,
            ff_mult=4
        )
        return image_proj_model

    def init_proj_full(self):
        # 初始化全功能图像投影模型
        image_proj_model = MLPProjModel(
            cross_attention_dim=self.cross_attention_dim,
            clip_embeddings_dim=self.clip_embeddings_dim
        )
        return image_proj_model

    def init_proj_faceid(self):
        # 初始化用于人脸识别的图像投影模型
        if self.is_plus:
            image_proj_model = ProjModelFaceIdPlus(
                cross_attention_dim=self.cross_attention_dim,
                id_embeddings_dim=512,
                clip_embeddings_dim=self.clip_embeddings_dim, # 1280,
                num_tokens=self.clip_extra_context_tokens, # 4,
            )
        else:
            image_proj_model = MLPProjModelFaceId(
                cross_attention_dim=self.cross_attention_dim,
                id_embeddings_dim=512,
                num_tokens=self.clip_extra_context_tokens,
            )
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, clip_embed, clip_embed_zeroed, batch_size):
        # 获取图像嵌入,在推理模式下运行
        torch_device = model_management.get_torch_device()
        # 获取计算设备(CPU 或 GPU)
        intermediate_device = model_management.intermediate_device()
        # 获取中间设备(通常用于中间结果的存储)

        if batch_size == 0:
            batch_size = clip_embed.shape[0]
            intermediate_device = torch_device
            # 如果批量大小为 0,设置为 CLIP 嵌入的批量大小,并将中间设备设置为计算设备
        elif batch_size > clip_embed.shape[0]:
            batch_size = clip_embed.shape[0]
            # 如果批量大小大于 CLIP 嵌入的批量大小,调整批量大小

        clip_embed = torch.split(clip_embed, batch_size, dim=0)
        # 将 CLIP 嵌入按批量大小拆分
        clip_embed_zeroed = torch.split(clip_embed_zeroed, batch_size, dim=0)
        # 将零化的 CLIP 嵌入按批量大小拆分

        image_prompt_embeds = []
        # 初始化图像提示嵌入列表
        uncond_image_prompt_embeds = []
        # 初始化无条件图像提示嵌入列表

        for ce, cez in zip(clip_embed, clip_embed_zeroed):
            image_prompt_embeds.append(self.image_proj_model(ce.to(torch_device)).to(intermediate_device))
            # 对每个 CLIP 嵌入进行投影,并移动到中间设备
            uncond_image_prompt_embeds.append(self.image_proj_model(cez.to(torch_device)).to(intermediate_device))
            # 对每个零化的 CLIP 嵌入进行投影,并移动到中间设备
        
        del clip_embed, clip_embed_zeroed
        # 删除原始 CLIP 嵌入和零化的 CLIP 嵌入

        image_prompt_embeds = torch.cat(image_prompt_embeds, dim=0)
        # 将图像提示嵌入按批量大小连接
        uncond_image_prompt_embeds = torch.cat(uncond_image_prompt_embeds, dim=0)
        # 将无条件图像提示嵌入按批量大小连接
        
        torch.cuda.empty_cache()
        # 清理 GPU 缓存

        #image_prompt_embeds = self.image_proj_model(clip_embed)
        #uncond_image_prompt_embeds = self.image_proj_model(clip_embed_zeroed)
        return image_prompt_embeds, uncond_image_prompt_embeds
        # 返回图像提示嵌入和无条件图像提示嵌入

    @torch.inference_mode()
    def get_image_embeds_faceid_plus(self, face_embed, clip_embed, s_scale, shortcut, batch_size):
        # 获取人脸识别增强型图像嵌入,在推理模式下运行
        torch_device = model_management.get_torch_device()
        # 获取计算设备(CPU 或 GPU)
        intermediate_device = model_management.intermediate_device()
        # 获取中间设备(通常用于中间结果的存储)

        if batch_size == 0:
            batch_size = clip_embed.shape[0]
            intermediate_device = torch_device
            # 如果批量大小为 0,设置为 CLIP 嵌入的批量大小,并将中间设备设置为计算设备
        elif batch_size > clip_embed.shape[0]:
            batch_size = clip_embed.shape[0]
            # 如果批量大小大于 CLIP 嵌入的批量大小,调整批量大小
        
        face_embed_batch = torch.split(face_embed, batch_size, dim=0)
        # 将人脸嵌入按批量大小拆分
        clip_embed_batch = torch.split(clip_embed, batch_size, dim=0)
        # 将 CLIP 嵌入按批量大小拆分

        embeds = []
        # 初始化嵌入列表
        for face_embed, clip_embed in zip(face_embed_batch, clip_embed_batch):
            embeds.append(self.image_proj_model(face_embed.to(torch_device), clip_embed.to(torch_device), scale=s_scale, shortcut=shortcut).to(intermediate_device))
            # 对每个人脸嵌入和 CLIP 嵌入进行投影,并移动到中间设备

        del face_embed_batch, clip_embed_batch
        # 删除拆分的嵌入

        embeds = torch.cat(embeds, dim=0)
        # 将嵌入按批量大小连接
        torch.cuda.empty_cache()
        # 清理 GPU 缓存

        #embeds = self.image_proj_model(face_embed, clip_embed, scale=s_scale, shortcut=shortcut)
        return embeds
        # 返回生成的嵌入

主要步骤和功能

  1. 类定义和初始化

    • IPAdapter 类继承自 nn.Module,用于初始化图像投影模型和相关配置参数。
    • 根据不同配置初始化适合的图像投影模型,并加载模型状态。
  2. 初始化方法

    • init_projinit_proj_plusinit_proj_fullinit_proj_faceid 方法用于初始化不同类型的图像投影模型。
    • 每个方法根据模型类型和配置参数创建并返回相应的图像投影模型。
  3. 获取图像嵌入

    • get_image_embeds 方法在推理模式下运行,处理并返回图像嵌入。
    • 该方法将输入的 CLIP 嵌入向量进行分批处理,并使用图像投影模型生成图像提示嵌入和无条件图像提示嵌入。
  4. 获取人脸识别增强型图像嵌入

    • get_image_embeds_faceid_plus 方法在推理模式下运行,处理并返回人脸识别增强型图像嵌入。
    • 该方法将输入的人脸嵌入和 CLIP 嵌入向量进行分批处理,并使用图像投影模型生成增强型嵌入。

这个类的主要作用是根据不同的配置和输入,使用相应的图像投影模型生成图像嵌入,适用于图像生成和人脸识别等任务。它可以与大模型(如 CLIP)一起使用,以增强图像处理的效果。

  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值