DDIM模型代码解析(一)

目录

预备知识

main.py

解析命令行参数

解析配置文件


预备知识

由于代码中除了一些必要的对模型、数据进行操作的PyTorch函数外,还有一些辅助显示训练等过程有关信息的,或辅助对文件目录进行操作的库。因此,建议读者先对这些库进行了解,试着写一写示例代码,理解库中函数的使用方法后再阅读下面的讲解,这样可以更顺畅。

import argparse
import traceback
import shutil
import logging
import yaml
import sys
import os

main.py

首先对输出的选项进行设定,让输出的内容不按科学计数法模式。

torch.set_printoptions(sci_mode=False)  # 设置为不按照科学计数法表示输出

然后程序进入main()函数中,在main函数中完成了以下任务:

  • 解析命令行参数
  • 解析配置文件
  • 打印相关信息
  • 扩散过程实例化
  • 完成采样 / 测试 / 训练过程

后面我们逐一进行代码分析。

def main():
    args, config = parse_args_and_config()  # 解析命令行参数和配置文件
    logging.info("Writing log file to {}".format(args.log_path))  # 显示日志存储路径信息
    logging.info("Exp instance id = {}".format(os.getpid()))  # 显示进程id信息
    logging.info("Exp comment = {}".format(args.comment))  # 显示实验注释信息

    try:
        runner = Diffusion(args, config)  # 构建扩散运行实例对象
        if args.sample:  # 如果是采样操作,就执行采样函数
            runner.sample()
        elif args.test:  # 如果是测试模型,就执行测试函数
            runner.test()
        else:  # 否则就执行训练函数
            runner.train()
    except Exception:  # 如果报错就输出错误信息日志
        logging.error(traceback.format_exc())

    return 0

解析命令行参数

对命令行参数的解析在parse_args_and_config函数中完成,每一个参数的含义以注释的形式标明,如果有异议欢迎在评论中指出。

def parse_args_and_config():
    parser = argparse.ArgumentParser(description=globals()["__doc__"])

    parser.add_argument(  # config文件路径
        "--config", type=str, required=True, help="Path to the config file"
    )
    parser.add_argument("--seed", type=int, default=1234, help="Random seed")  # 随机种子
    parser.add_argument(  # 用于保存运行相关数据的路径
        "--exp", type=str, default="exp", help="Path for saving running related data."
    )
    parser.add_argument(  # log日志文件夹名称
        "--doc",
        type=str,
        required=True,
        help="A string for documentation purpose. "
        "Will be the name of the log folder.",
    )
    parser.add_argument(  # 实验注释
        "--comment", type=str, default="", help="A string for experiment comment"
    )
    parser.add_argument(  # logging日志的级别: info, debug, warning, critical
        "--verbose",
        type=str,
        default="info",
        help="Verbose level: info | debug | warning | critical",
    )
    parser.add_argument("--test", action="store_true", help="Whether to test the model")  # 是否测试模型
    parser.add_argument(  # 是否从模型产生采样
        "--sample",
        action="store_true",
        help="Whether to produce samples from the model",
    )
    parser.add_argument("--fid", action="store_true")  # FID指标
    parser.add_argument("--interpolation", action="store_true")  # 插值
    parser.add_argument(  # 是否为继续训练
        "--resume_training", action="store_true", help="Whether to resume training"
    )
    parser.add_argument(  # 采样的文件夹名称
        "-i",
        "--image_folder",
        type=str,
        default="images",
        help="The folder name of samples",
    )
    parser.add_argument(  # 无交互
        "--ni",
        action="store_true",
        help="No interaction. Suitable for Slurm Job launcher",
    )
    parser.add_argument("--use_pretrained", action="store_true")  # 使用预训练
    parser.add_argument(  # 采样类型
        "--sample_type",
        type=str,
        default="generalized",
        help="sampling approach (generalized or ddpm_noisy)",
    )
    parser.add_argument(  # 跳跃类型
        "--skip_type",
        type=str,
        default="uniform",
        help="skip according to (uniform or quadratic)",
    )
    parser.add_argument(  # 步数
        "--timesteps", type=int, default=1000, help="number of steps involved"
    )
    parser.add_argument(  # \eta超参数用于控制方差
        "--eta",
        type=float,
        default=0.0,
        help="eta used to control the variances of sigma",
    )
    parser.add_argument("--sequence", action="store_true")  # 是否为序列

    args = parser.parse_args()  # 解析参数
    args.log_path = os.path.join(args.exp, "logs", args.doc)  # log日志路径: exp/logs/$doc$
    
    ...

解析配置文件

解析配置文件的过程也是在parse_args_and_config函数中,args.config应该是bedroom,celeba,church,cifar10中的一个。这样我们可以直接打开文件夹configs中对应数据集的yaml配置文件,此时config为字典类型。经过dict2namespace函数,将字典类型转换为argparse中命名空间的形式。

def parse_args_and_config():
    ...

    # parse config file
    with open(os.path.join("configs", args.config), "r") as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

    ...

转换函数如下:

def dict2namespace(config):
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace

之后还有一步设定tensorboard日志的路径,可以在训练时用tensorboard查看训练进度信息:

def parse_args_and_config():
    ...

    tb_path = os.path.join(args.exp, "tensorboard", args.doc)  # tensorboard日志路径: exp/tensorboard/$doc$

    ...

之后会执行训练 / 采样 / 测试不同的代码部分:

首先看一下对于训练会执行的代码:

  • 创建log日志文件夹
  • 创建tensorboard日志文件夹
  • 设置logging的logger
def parse_args_and_config():
    ...

    if not args.test and not args.sample:
        if not args.resume_training:
            if os.path.exists(args.log_path):  # 如果log输出路径存在的话
                overwrite = False  # 选择不覆盖
                if args.ni:  # 如果ni为True
                    overwrite = True  # 选择覆盖
                else:
                    response = input("Folder already exists. Overwrite? (Y/N)")  # 询问是否覆盖
                    if response.upper() == "Y":  # 如果Y, 则选择覆盖原有log
                        overwrite = True

                if overwrite:  # 如果选择覆盖
                    shutil.rmtree(args.log_path)  # 删除原有log文件路径
                    shutil.rmtree(tb_path)  # 删除原有tensorboard文件路径
                    os.makedirs(args.log_path)  # 创建新的log文件路径
                    if os.path.exists(tb_path):  # 如果tensorboard文件路径存在, 就删除它
                        shutil.rmtree(tb_path)
                else:  # 如果选择不覆盖, 则提示文件夹存在, 程序停止
                    print("Folder exists. Program halted.")
                    sys.exit(0)
            else:  # 如果log输出路径不存在就创建路径
                os.makedirs(args.log_path)

            with open(os.path.join(args.log_path, "config.yml"), "w") as f:
                yaml.dump(new_config, f, default_flow_style=False)

        new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path)
        # setup logger
        level = getattr(logging, args.verbose.upper(), None)  # 20 (logging.INFO) 或者其它的级别
        if not isinstance(level, int):  # 如果为None的话就会报错
            raise ValueError("level {} not supported".format(args.verbose))

        handler1 = logging.StreamHandler()  # 将log在CLI输出的handler
        handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt"))  # 将log在文件输出的handler
        formatter = logging.Formatter(  # 控制log输出格式的formatter
            "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"  # INFO - __main__ - ... - ....
        )
        handler1.setFormatter(formatter)  # 设置CLI输出handler的格式
        handler2.setFormatter(formatter)  # 设置文件输出handler的格式
        logger = logging.getLogger()  # root logger
        logger.addHandler(handler1)  # 添加CLI输出handler
        logger.addHandler(handler2)  # 添加文件输出handler
        logger.setLevel(level)  # 设定root logger的级别

    ...

然后是采样 / 测试会执行的代码:

  • 设置logging的logger
  • 对于采样,会创建图像文件夹
def parse_args_and_config():
    ...


    else:
        level = getattr(logging, args.verbose.upper(), None)
        if not isinstance(level, int):
            raise ValueError("level {} not supported".format(args.verbose))

        handler1 = logging.StreamHandler()
        formatter = logging.Formatter(
            "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
        )
        handler1.setFormatter(formatter)
        logger = logging.getLogger()
        logger.addHandler(handler1)
        logger.setLevel(level)

        if args.sample:  # 如果是采样
            os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True)  # 创建目录: exp/image_samples
            args.image_folder = os.path.join(  # 添加图像文件夹参数: exp/image_samples/$image_folder$
                args.exp, "image_samples", args.image_folder
            )
            if not os.path.exists(args.image_folder):  # 如果图像文件夹不存在就创建一个
                os.makedirs(args.image_folder)
            else:  # 如果图像文件夹存在
                if not (args.fid or args.interpolation):
                    overwrite = False
                    if args.ni:
                        overwrite = True
                    else:
                        response = input(
                            f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)"
                        )
                        if response.upper() == "Y":
                            overwrite = True

                    if overwrite:  # 如果覆盖, 删除并新建文件夹
                        shutil.rmtree(args.image_folder)
                        os.makedirs(args.image_folder)
                    else:
                        print("Output image folder exists. Program halted.")
                        sys.exit(0)

    ...

最后是对PyTorch进行设置:

  • device
  • 随机种子
  • causes cuDNN to benchmark multiple convolution algorithms and select the fastest.
def parse_args_and_config():
    ...

    # add device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    logging.info("Using device: {}".format(device))
    new_config.device = device

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.benchmark = True

    return args, new_config

至此,就基本结束main.py的学习了,后面讲进入Diffusion类中查看具体初始化、训练、采样、测试这些函数是如何实现的了。

  • 5
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值