IC-Light 在神经网络框架中的集成

IC-Light 在神经网络框架中的集成

导入语句

import torch
import torchvision.transforms as transforms
import folder_paths
import os
import types
import numpy as np
import torch.nn.functional as F
from comfy.utils import load_torch_file
from .utils.convert_unet import convert_iclight_unet
from .utils.patches import calculate_weight_adjust_channel
from .utils.image import generate_gradient_image, LightPosition
from nodes import MAX_RESOLUTION
from comfy.model_patcher import ModelPatcher
import model_management

这些导入语句提供了执行数组操作、深度学习模型操作、文件路径处理等功能所需的各种模块和函数。特别是 torchtorchvision 库用于深度学习相关的操作,而 folder_pathsos 用于文件系统操作。

LoadAndApplyICLightUnet 的定义

输入类型定义
@classmethod
def INPUT_TYPES(s):
    return {
        "required": {
            "model": ("MODEL",),
            "model_path": (folder_paths.get_filename_list("unet"), )
        } 
    }

此方法定义了必需的输入类型,包括一个模型对象和一个模型路径,后者使用 folder_paths.get_filename_list("unet") 函数获取UNet模型文件的列表。

加载函数
def load(self, model, model_path):
    type_str = str(type(model.model.model_config).__name__)
    if "SD15" not in type_str:
        raise Exception(f"Attempted to load {type_str} model, IC-Light is only compatible with SD 1.5 models.")

    print("LoadAndApplyICLightUnet: Checking IC-Light Unet path")
    model_full_path = folder_paths.get_full_path("unet", model_path)
    if not os.path.exists(model_full_path):
        raise Exception("Invalid model path")
    else:
        print("LoadAndApplyICLightUnet: Loading IC-Light Unet weights")
        model_clone = model.clone()

        iclight_state_dict = load_torch_file(model_full_path)
        
        print("LoadAndApplyICLightUnet: Attempting to add patches with IC-Light Unet weights")
        try:
            if 'conv_in.weight' in iclight_state_dict:
                iclight_state_dict = convert_iclight_unet(iclight_state_dict)
                in_channels = iclight_state_dict["diffusion_model.input_blocks.0.0.weight"].shape[1]
                for key in iclight_state_dict:
                    model_clone.add_patches({key: (iclight_state_dict[key],)}, 1.0, 1.0)
            else:
                for key in iclight_state_dict:
                    model_clone.add_patches({"diffusion_model." + key: (iclight_state_dict[key],)}, 1.0, 1.0)

                in_channels = iclight_state_dict["input_blocks.0.0.weight"].shape[1]

        except:
            raise Exception("Could not patch model")
        print("LoadAndApplyICLightUnet: Added LoadICLightUnet patches")

        try:
            ModelPatcher.calculate_weight = calculate_weight_adjust_channel(ModelPatcher.calculate_weight)
        except:
            raise Exception("IC-Light: Could not patch calculate_weight")
        # Mimic the existing IP2P class to enable extra_conds
        def bound_extra_conds(self, **kwargs):
             return ICLight.extra_conds(self, **kwargs)
        new_extra_conds = types.MethodType(bound_extra_conds, model_clone.model)
        model_clone.add_object_patch("extra_conds", new_extra_conds)
        

        model_clone.model.model_config.unet_config["in_channels"] = in_channels        

        return (model_clone, )

此函数负责加载并应用IC-Light的UNet模型。它首先检查模型是否符合兼容性(仅限SD 1.5模型),然后加载模型权重,并尝试将这些权重作为补丁应用到克隆的模型上。此外,它还修改了模型的权重计算方法,并添加了额外的条件处理方法以增强功能。

ICLightConditioning 的定义

输入类型定义
@classmethod
def INPUT_TYPES(s):
    return {
        "required": {
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "vae": ("VAE", ),


            "foreground": ("LATENT", ),
            "multiplier": ("FLOAT", {"default": 0.18215, "min": 0.0, "max": 1.0, "step": 0.001}),
        },
        "optional": {
            "opt_background": ("LATENT", ),
        },
    }

定义了IC-Light调节所需的输入参数,包括正负调节条件、VAE模型、前景潜像、乘数以及可选的背景潜像。

编码函数
def encode(self, positive, negative, vae, foreground, multiplier, opt_background=None):
    samples_1 = foreground["samples"]

    if opt_background is not None:
        samples_2 = opt_background["samples"]

        repeats_1 = samples_2.size(0) // samples_1.size(0)
        repeats_2 = samples_1.size(0) // samples_2.size(0)
        if samples_1.shape[1:] != samples_2.shape[1:]:
            samples_2 = comfy.utils.common_upscale(samples_2, samples_1.shape[-1], samples_1.shape[-2], "bilinear", "disabled")

        if repeats_1 > 1:
            samples_1 = samples_1.repeat(repeats_1, 1, 1, 1)
        if repeats_2 > 1:
            samples_2 = samples_2.repeat(repeats_2, 1, 1, 1)

        concat_latent = torch.cat((samples_1, samples_2), dim=1)
    else:
        concat_latent = samples_1

    out_latent = torch.zeros_like(samples_1)

    out = []
    for conditioning in [positive, negative]:
        c = []
        for t in conditioning:
            d = t[1].copy()
            d["concat_latent_image"] = concat_latent * multiplier
            n = [t[0], d]
            c.append(n)
        out.append(c)
    return (out[0], out[1], {"samples": out_latent})

此函数处理传入的正负条件,应用VAE和潜像,并将乘数应用于组合潜像以生成新的调节结果。它支持处理可选的背景潜像,允许更复杂的图像合成操作。

节点类和显示名称的注册

NODE_CLASS_MAPPINGS = {
    "LoadAndApplyICLightUnet": LoadAndApplyICLightUnet,
    "ICLightConditioning": ICLightConditioning,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "LoadAndApplyICLightUnet": "Load And Apply IC-Light",
    "ICLightConditioning": "IC-Light Conditioning",
}

这些代码行将上述定义的类注册为可用的节点类型,并设置它们的显示名称,使得这些功能可以在更大的系统或框架中通过指定的名称被调用。

  • 4
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值