用Python实现LoadImage类

这段代码定义了一个名为 LoadImage 的类,用于从指定目录中加载图像文件,并将其转换为适合处理的格式。该类还支持生成图像的遮罩,并提供了输入验证和更改检测功能。下面是对这段代码的详细解释:

LoadImage

类方法 INPUT_TYPES
@classmethod
def INPUT_TYPES(s):
    input_dir = folder_paths.get_input_directory()
    files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
    return {"required":
                {"image": (sorted(files), {"image_upload": True})},
            }
  1. 获取输入目录路径:通过 folder_paths.get_input_directory() 获取输入目录路径。
  2. 获取目录中的文件列表:使用 os.listdiros.path.isfile 列出输入目录中的所有文件。
  3. 返回输入类型:定义输入参数 image,其值为输入目录中排序后的文件列表,并允许上传图像。
类属性 CATEGORYRETURN_TYPES
CATEGORY = "image"
RETURN_TYPES = ("IMAGE", "MASK")
  1. CATEGORY:将该类归类为 "image"。
  2. RETURN_TYPES:定义返回类型为 "IMAGE" 和 "MASK"。
方法 load_image
def load_image(self, image):
    image_path = folder_paths.get_annotated_filepath(image)
    
    img = node_helpers.pillow(Image.open, image_path)
    
    output_images = []
    output_masks = []
    w, h = None, None

    excluded_formats = ['MPO']
    
    for i in ImageSequence.Iterator(img):
        i = node_helpers.pillow(ImageOps.exif_transpose, i)

        if i.mode == 'I':
            i = i.point(lambda i: i * (1 / 255))
        image = i.convert("RGB")

        if len(output_images) == 0:
            w = image.size[0]
            h = image.size[1]
        
        if image.size[0] != w or image.size[1] != h:
            continue
        
        image = np.array(image).astype(np.float32) / 255.0
        image = torch.from_numpy(image)[None,]
        if 'A' in i.getbands():
            mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
            mask = 1. - torch.from_numpy(mask)
        else:
            mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
        output_images.append(image)
        output_masks.append(mask.unsqueeze(0))

    if len(output_images) > 1 and img.format not in excluded_formats:
        output_image = torch.cat(output_images, dim=0)
        output_mask = torch.cat(output_masks, dim=0)
    else:
        output_image = output_images[0]
        output_mask = output_masks[0]

    return (output_image, output_mask)
  1. 获取图像路径:通过 folder_paths.get_annotated_filepath(image) 获取图像的完整路径。
  2. 打开图像:使用 node_helpers.pillowImage.open 打开图像文件。
  3. 初始化输出图像和遮罩列表:output_imagesoutput_masks
  4. 获取图像宽度和高度:初始化 whNone
  5. 排除格式:定义 excluded_formats 列表,用于排除特定格式。
  6. 遍历图像序列:使用 ImageSequence.Iterator(img) 遍历图像帧。
    • 转置图像:使用 ImageOps.exif_transpose 进行图像转置。
    • 处理图像模式:如果图像模式为 "I",则将其转换为 0-1 范围。
    • 转换为 RGB:将图像转换为 RGB 模式。
    • 获取图像尺寸:在处理第一帧时,获取图像的宽度和高度。
    • 忽略尺寸不匹配的帧:如果后续帧的尺寸与第一帧不一致,则跳过。
    • 转换为 NumPy 数组并标准化:将图像转换为 NumPy 数组,并将像素值标准化到 0-1 范围。
    • 转换为 PyTorch 张量:将图像转换为 PyTorch 张量。
    • 处理 Alpha 通道:如果图像包含 Alpha 通道,则提取并转换为遮罩。
    • 添加图像和遮罩:将处理后的图像和遮罩添加到列表中。
  7. 合并多帧图像和遮罩:如果图像包含多帧且不在排除格式中,则合并所有帧,否则只使用第一帧。
  8. 返回图像和遮罩:返回处理后的图像和遮罩。
类方法 IS_CHANGED
@classmethod
def IS_CHANGED(s, image):
    image_path = folder_paths.get_annotated_filepath(image)
    m = hashlib.sha256()
    with open(image_path, 'rb') as f:
        m.update(f.read())
    return m.digest().hex()
  1. 获取图像路径:通过 folder_paths.get_annotated_filepath(image) 获取图像的完整路径。
  2. 计算哈希值:使用 hashlib.sha256() 计算图像文件的哈希值,以检测文件是否发生变化。
  3. 返回哈希值:返回计算得到的哈希值。
类方法 VALIDATE_INPUTS
@classmethod
def VALIDATE_INPUTS(s, image):
    if not folder_paths.exists_annotated_filepath(image):
        return "Invalid image file: {}".format(image)

    return True
  1. 验证图像文件是否存在:通过 folder_paths.exists_annotated_filepath(image) 检查图像文件是否存在。
  2. 返回验证结果:如果文件不存在,则返回错误信息;否则返回 True

总结

LoadImage 类提供了从指定目录加载图像文件,并将其转换为适合处理的格式(包括生成遮罩)的功能。它还包括输入验证和文件变化检测功能。这个类可以用于图像处理和生成任务,确保输入图像格式正确,并生成必要的遮罩。

  • 7
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值