DiffBIR/scripts

文章介绍了使用PyTorchLightning进行图像超分辨率任务的脚本,涉及模型训练、权重管理和样本数据生成,展示了深度学习项目开发流程的关键组件。
摘要由CSDN通过智能技术生成

 ## DiffBIR/scripts/inference_stage1.py

import sys
sys.path.append(".")
import os
from argparse import ArgumentParser, Namespace

import pytorch_lightning as pl
from omegaconf import OmegaConf
import torch
from PIL import Image
import numpy as np
from tqdm import tqdm

from utils.image import auto_resize, pad
from utils.common import load_state_dict, instantiate_from_config
from utils.file import list_image_files, get_file_name_parts


def parse_args() -> Namespace:
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--sr_scale", type=float, default=1)
    parser.add_argument("--image_size", type=int, default=512)
    parser.add_argument("--show_lq", action="store_true")
    parser.add_argument("--resize_back", action="store_true")
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--skip_if_exist", action="store_true")
    parser.add_argument("--seed", type=int, default=231)
    return parser.parse_args()


@torch.no_grad()
def main():
    args = parse_args()
    pl.seed_everything(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    model: pl.LightningModule = instantiate_from_config(OmegaConf.load(args.config))
    load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
    model.freeze()
    model.to(device)
    
    assert os.path.isdir(args.input)
    
    pbar = tqdm(list_image_files(args.input, follow_links=True))
    for file_path in pbar:
        pbar.set_description(file_path)
        save_path = os.path.join(args.output, os.path.relpath(file_path, args.input))
        parent_path, stem, _ = get_file_name_parts(save_path)
        save_path = os.path.join(parent_path, f"{stem}.png")
        if os.path.exists(save_path):
            if args.skip_if_exist:
                print(f"skip {save_path}")
                continue
            else:
                raise RuntimeError(f"{save_path} already exist")
        os.makedirs(parent_path, exist_ok=True)
        
        # load low-quality image and resize
        lq = Image.open(file_path).convert("RGB")
        if args.sr_scale != 1:
            lq = lq.resize(
                tuple(int(x * args.sr_scale) for x in lq.size), Image.BICUBIC
            )
        lq_resized = auto_resize(lq, args.image_size)
        # padding
        x = pad(np.array(lq_resized), scale=64)

        x = torch.tensor(x, dtype=torch.float32, device=device) / 255.0
        x = x.permute(2, 0, 1).unsqueeze(0).contiguous()
        try:
            pred = model(x).detach().squeeze(0).permute(1, 2, 0) * 255
            pred = pred.clamp(0, 255).to(torch.uint8).cpu().numpy()
        except RuntimeError as e:
            print(f"inference failed, error: {e}")
            continue
        
        # remove padding
        pred = pred[:lq_resized.height, :lq_resized.width, :]
        if args.show_lq:
            if args.resize_back:
                lq = np.array(lq)
                if lq_resized.size != lq.size:
                    pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
            else:
                lq = np.array(lq_resized)
            final_image = Image.fromarray(np.concatenate([lq, pred], axis=1))
        else:
            if args.resize_back and lq_resized.size != lq.size:
                final_image = Image.fromarray(pred).resize(lq.size, Image.LANCZOS)
            else:
                final_image = Image.fromarray(pred)
        final_image.save(save_path)


if __name__ == "__main__":
    main()

使用PyTorch Lightning和其他库来处理图像超分辨率(Super-Resolution, SR)任务的脚本。它从低分辨率(Low-Quality, LQ)图像生成高分辨率(High-Quality, HQ)图像。下面是代码的主要部分和它们的作用解释:

  1. 导入库

    1. 使用了多个Python库,包括sysos用于系统级别的操作,argparse用于解析命令行参数,pytorch_lightning(简称pl)用于机器学习任务的简化接口,以及PILnumpy等用于图像处理。
  2. 参数解析(parse_args函数):

    1. 定义了命令行接口(CLI),允许用户指定配置文件、检查点文件(模型权重)、输入文件夹(包含低分辨率图像)、输出文件夹和其他一些选项,例如是否回复到原始大小、是否跳过已存在的输出等。
  3. 主函数(main):

    1. 解析命令行参数。
    2. 设置随机种子以确保结果的可复现性。
    3. 加载模型和权重,将模型设置为评估模式,并将其移动到适当的设备(GPU或CPU)。
    4. 验证输入路径是一个目录。
    5. 遍历目录中的所有图像文件,对每个文件进行处理。
      1. 调整图像大小和/或应用填充以满足模型输入要求。
      2. 将处理过的图像转换为PyTorch张量,然后送入模型进行超分辨率处理。
      3. 将模型输出转换回PIL图像,保存到指定的输出路径。
  4. 图像处理步骤

    1. 加载图像并可能进行缩放。
    2. 应用自动调整大小和填充,以确保图像大小符合模型要求。
    3. 将图像数据转换为模型需要的格式(归一化,重排维度)。
    4. 进行模型推理以获得高分辨率图像。
    5. 根据用户的选项,可能将结果图像与原始低分辨率图像并排显示,或者只显示高分辨率图像。
    6. 保存最终图像到指定路径。
  5. 错误处理

    1. 如果指定了--skip_if_exist标志,当输出文件已存在时,该文件将被跳过。
    2. 如果遇到运行时错误(如模型推理失败),将打印错误信息并继续处理下一张图像。
  6. 模型和权重

    1. 使用instantiate_from_config函数从配置文件实例化模型。
    2. 使用load_state_dict函数加载检查点文件中的模型权重。
  7. 配置和检查点文件

    1. 用户必须提供一个配置文件和一个检查点文件。配置文件定义了模型结构和可能的其他参数。检查点文件包含了预训练模型的权重。

这段代码为处理图像超分辨率任务提供了一个完整的流程,包括数据预处理、模型推理和结果后处理。

##DiffBIR/scripts /make_stage2_init_weight.py

import sys
sys.path.append(".")
from argparse import ArgumentParser
from typing import Dict

import torch
from omegaconf import OmegaConf

from utils.common import instantiate_from_config


def load_weight(weight_path: str) -> Dict[str, torch.Tensor]:
    weight = torch.load(weight_path)
    if "state_dict" in weight:
        weight = weight["state_dict"]

    pure_weight = {}
    for key, val in weight.items():
        if key.startswith("module."):
            key = key[len("module."):]
        pure_weight[key] = val

    return pure_weight

parser = ArgumentParser()
parser.add_argument("--cldm_config", type=str, required=True)
parser.add_argument("--sd_weight", type=str, required=True)
parser.add_argument("--swinir_weight", type=str, required=True)
parser.add_argument("--output", type=str, required=True)
args = parser.parse_args()

model = instantiate_from_config(OmegaConf.load(args.cldm_config))

sd_weights = load_weight(args.sd_weight)
swinir_weights = load_weight(args.swinir_weight)
scratch_weights = model.state_dict()

init_weights = {}
for weight_name in scratch_weights.keys():
    # find target pretrained weights for this weight
    if weight_name.startswith("control_"):
        suffix = weight_name[len("control_"):]
        target_name = f"model.diffusion_{suffix}"
        target_model_weights = sd_weights
    elif weight_name.startswith("preprocess_model."):
        suffix = weight_name[len("preprocess_model."):]
        target_name = suffix
        target_model_weights = swinir_weights
    elif weight_name.startswith("cond_encoder."):
        suffix = weight_name[len("cond_encoder."):]
        target_name = F"first_stage_model.{suffix}"
        target_model_weights = sd_weights
    else:
        target_name = weight_name
        target_model_weights = sd_weights
    
    # if target weight exist in pretrained model
    print(f"copy weights: {target_name} -> {weight_name}")
    if target_name in target_model_weights:
        # get pretrained weight
        target_weight = target_model_weights[target_name]
        target_shape = target_weight.shape
        model_shape = scratch_weights[weight_name].shape
        # if pretrained weight has the same shape with model weight, we make a copy
        if model_shape == target_shape:
            init_weights[weight_name] = target_weight.clone()
        # else we copy pretrained weight with additional channels initialized to zero
        else:
            newly_added_channels = model_shape[1] - target_shape[1]
            oc, _, h, w = target_shape
            zero_weight = torch.zeros((oc, newly_added_channels, h, w)).type_as(target_weight)
            init_weights[weight_name] = torch.cat((target_weight.clone(), zero_weight), dim=1)
            print(f"add zero weight to {target_name} in pretrained weights, newly added channels = {newly_added_channels}")
    else:
        init_weights[weight_name] = scratch_weights[weight_name].clone()
        print(f"These weights are newly added: {weight_name}")

model.load_state_dict(init_weights, strict=True)
torch.save(model.state_dict(), args.output)
print("Done.")

这段代码是一个Python脚本,主要用于初始化一个深度学习模型的权重。它结合了两个预训练模型的权重,并处理了不同部分的权重初始化。这个脚本的运行流程和主要组成部分可以这样概括:

  1. 导入所需的库:除了标准库sys之外,还导入了用于命令行参数解析的argparsetorch用于处理深度学习模型和操作,OmegaConf用于处理配置文件,以及脚本自定义的instantiate_from_config函数用于根据配置文件实例化模型。

  2. 定义权重加载函数 (load_weight):

    • 加载指定路径的权重文件。
    • 如果权重文件中包含"state_dict"键,则使用其值;这通常是在模型训练后保存的标准做法。
    • 清理权重键名,移除前缀"module.",这是由于在使用DataParallel等封装模块时,保存的权重键名会自动添加这个前缀。
    • 返回处理后的权重字典。
  3. 解析命令行参数:允许用户指定两个预训练权重文件(--sd_weight--swinir_weight)、一个配置文件(--cldm_config),以及输出文件的路径(--output)。

  4. 模型和权重处理

    • 使用配置文件实例化模型。
    • 加载并处理两组预训练权重。
    • 获取模型的初始权重状态。
  5. 权重初始化逻辑

    • 对于模型中的每个权重,判断它应该使用哪组预训练权重(基于权重名称的前缀)。
    • 如果找到对应的预训练权重,并且形状相匹配,则直接复制权重。
    • 如果形状不匹配,尤其是在通道数方面,会在必要的维度上添加零权重,以保证权重的形状一致性。
    • 如果模型中的权重在预训练权重中没有找到对应项,则保留模型的初始权重。
  6. 模型权重加载:使用处理好的权重初始化模型。

  7. 保存模型权重:将初始化后的模型权重保存到指定的输出文件。

  8. 脚本完成运行:打印"Done."表示脚本成功完成运行。

这个脚本在深度学习研究和应用中很有用,特别是在需要结合不同来源的预训练权重来初始化模型时。这种方法可以加速模型的收敛,提高模型在特定任务上的表现。

## DiffBIR/scripts /sample_dataset.py

import sys
sys.path.append(".")
from argparse import ArgumentParser
import os
from typing import Any

from omegaconf import OmegaConf
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image
import pytorch_lightning as pl

from utils.common import instantiate_from_config


def wrap_dataloader(data_loader: DataLoader) -> Any:
    while True:
        yield from data_loader


pl.seed_everything(231, workers=True)

parser = ArgumentParser()
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--sample_size", type=int, default=128)
parser.add_argument("--show_gt", action="store_true")
parser.add_argument("--output", type=str, required=True)
args = parser.parse_args()

config = OmegaConf.load(args.config)
dataset = instantiate_from_config(config.dataset)
transform = instantiate_from_config(config.batch_transform)
data_loader = wrap_dataloader(DataLoader(dataset, batch_size=1, shuffle=True))

cnt = 0
os.makedirs(args.output, exist_ok=True)

for batch in data_loader:
    batch = transform(batch)
    for hq, lq in zip(batch["jpg"], batch["hint"]):
        hq = ((hq + 1) * 127.5).numpy().clip(0, 255).astype(np.uint8)
        lq = (lq * 255.0).numpy().clip(0, 255).astype(np.uint8)
        if args.show_gt:
            Image.fromarray(np.concatenate([hq, lq], axis=1)).save(os.path.join(args.output, f"{cnt}.png"))
        else:
            Image.fromarray(lq).save(os.path.join(args.output, f"{cnt}.png"))
        cnt += 1
        if cnt >= args.sample_size:
            break
    if cnt >= args.sample_size:
        break

这段代码是一个使用PyTorch Lightning和其他相关库来生成图像样本的脚本。它主要用于从指定的数据集中采样图像,并可选地显示它们的高质量(Ground Truth, GT)版本,然后将这些图像保存到磁盘上。这是代码的详细解析:

  1. 导入库:导入了所需的Python库,包括系统操作、数据加载、图像处理和PyTorch Lightning等。

  2. 函数定义 (wrap_dataloader):

    • 定义了一个生成器函数,这个函数接受一个DataLoader对象作为输入,并在无限循环中迭代返回数据批次。这使得可以持续地从数据集中获取数据,直到满足一定的条件(如采样足够数量的样本)。
  3. 设置随机种子:通过pl.seed_everything固定随机种子以确保实验的可重复性。

  4. 命令行参数解析

    • 允许用户指定配置文件、样本大小、是否显示高质量版本的图像,以及输出目录。
  5. 加载配置文件和实例化数据集

    • 使用OmegaConf.load加载配置文件,该文件应指定如何实例化数据集和批处理转换。
    • 使用自定义的instantiate_from_config函数根据配置实例化数据集和批处理转换。
  6. 数据加载器和数据处理

    • 使用wrap_dataloader函数包装一个DataLoader实例,以便连续地从数据集中获取数据。
    • 设置一个计数器来跟踪已生成的样本数量。
    • 确保输出目录存在。
  7. 生成和保存图像样本

    • 通过循环获取数据批次,对每个批次应用批处理转换。
    • 对于每个批次中的图像,将它们从张量转换为NumPy数组,并调整值的范围,使其适合保存为PNG图像。
    • 如果--show_gt标志被设置,则将高质量图像(hq)和低质量图像(lq)并排保存;否则,仅保存低质量图像(lq)。
    • 当生成的样本数量达到用户指定的--sample_size时,停止生成。
  8. 输出

    • 生成的图像被保存在用户指定的输出目录中,文件名为连续的数字(从0开始),格式为PNG。

这个脚本可以用于数据探索、模型输入的可视化,或生成数据集的低质量和高质量图像对,用于机器学习模型的训练和评估。

这三段代码片段展示了如何设置和使用一个深度学习工作流,包括模型训练、模型初始化(权重加载)、以及生成和保存数据集样本。这些代码片段联动的方式主要通过文件和配置的共享来实现,以下是它们联动的概述:

1. 配置文件和命令行参数

所有三个脚本都依赖于命令行参数和/或配置文件来指定输入、输出和操作参数。配置文件(使用OmegaConf加载)提供了一种灵活的方式来定义模型、数据集以及其他相关设置。

  • 模型训练和权重加载脚本 需要指定模型配置和权重路径,这些可以在训练结束后生成,并用于初始化模型权重。
  • 数据样本生成脚本 需要数据集配置,这些配置定义了数据的来源、如何变换等,以及如何生成和保存样本。

2. 数据处理和模型训练

  • 模型训练脚本 通常使用一个配置文件来指定模型结构、训练参数和数据处理方式。训练完成后,会生成模型的权重文件。
  • 数据样本生成脚本 可以使用训练过程中定义的同一数据处理配置(或者是一个简化版本),以确保生成的样本与训练时使用的样本具有相同的预处理步骤。

3. 模型权重的初始化和应用

  • 权重加载脚本 使用训练完成后的权重文件来初始化一个模型。这个过程包括从预训练模型加载权重,并根据需要进行调整,以匹配目标模型的结构。
  • 一旦模型被初始化,它就可以用于各种任务,包括进一步的训练、评估或应用于新的数据。

联动流程示例:

  1. 训练模型:首先使用第一个脚本训练模型。模型训练完成后,保存配置文件和模型权重。

  2. 模型权重初始化:然后,使用第二个脚本加载模型权重。这可能是为了进行微调、评估或将模型应用于不同的任务。这个脚本读取训练阶段产生的权重文件,并可能结合其他预训练权重,初始化模型以供后续使用。

  3. 生成和保存数据样本:最后,第三个脚本利用配置文件来实例化和使用数据集。它可以用于生成模型输入的样本并将其保存,例如,为了进行模型评估或可视化。如果这个脚本生成的是模型的输入数据,那么它也可以配合已经通过第二个脚本初始化权重的模型来使用。

整个流程展示了从模型训练到权重初始化,再到数据处理和应用的完整生命周期。通过这种方式,可以灵活地将训练好的模型应用于不同的数据集和任务,同时确保数据处理和模型初始化的一致性和复用性。

  • 6
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值