【LatentDiffusion 代码详解(1)】LatentDiffusion 的 yaml 解读

54 篇文章 2 订阅
31 篇文章 0 订阅

YAML 文件提供了一种清晰、简洁且易于理解的方式来描述配置信息,特别适用于机器学习模型的超参数调优和实验管理。

以 Latent Diffusion 官方代码仓库中的 https://github.com/CompVis/latent-diffusion/blob/main/configs/autoencoder/autoencoder_kl_32x32x4.yaml 为例(如下),该 YAML 配置文件,用于定义训练一个自编码器模型的设置,其中包含 3 个部分:

  1. model (AutoencoderKL的模型结构)
  2. data(DataModuleFromConfig中如何读入数据)
  3. lightning(设置回调函数和训练器)
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 4
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,2,4,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [ ]
      dropout: 0.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2

Model

  • base_learning_rate: 4.5e-6: 这是基础学习率,用于优化器的初始化。学习率表示在每次参数更新时,参数被调整的程度。
  • target: ldm.models.autoencoder.AutoencoderKL: 这是要训练的模型的类路径,即模型定义代码所在的位置。
  • params: 这里是模型的参数设置。
    • monitor: "val/rec_loss": 监控的指标,通常是验证集上的重构损失。
    • embed_dim: 4: 嵌入维度,可能是自编码器中隐藏层的维度。
    • lossconfig: 损失函数的配置。
      • target: ldm.modules.losses.LPIPSWithDiscriminator: LPIPS损失所在位置。在这里插入图片描述

      • params: 参数设置。

        • disc_start: 50001: 鉴别器开始的步数。
        • kl_weight: 0.000001: KL散度的权重。
        • disc_weight: 0.5: 鉴别器权重。
    • ddconfig: 双向变换的配置。
      • double_z: True: 是否使用双向Z变换。
      • 其他参数是有关双向变换网络结构的设置,包括通道数量、分辨率、残差块数量等。

Data

  • target: main.DataModuleFromConfig: 数据模块的类路径。
  • params: 数据加载器的参数设置。
    • batch_size: 12: 批量大小,即每次迭代训练时传递给模型的样本数量。
    • wrap: True: 是否循环迭代数据。
    • train: 训练数据的设置。
      • target: ldm.data.imagenet.ImageNetSRTrain: 训练集加载器的类路径。
      • params: 参数设置。
        • size: 256: 数据的大小。
        • degradation: pil_nearest: 图像降质方法。
    • validation: 验证集的设置。
      • target: ldm.data.imagenet.ImageNetSRValidation: 验证数据加载器的类路径。
      • params: 参数设置,与训练数据类似。

Lightning

  • callbacks: 回调函数的设置。
    • image_logger: 图像记录器的设置。
      • target: main.ImageLogger: 图像记录器的类路径。
      • params: 参数设置。
        • batch_frequency: 1000: 记录图像的频率。
        • max_images: 8: 最大图像数量。
        • increase_log_steps: True: 是否逐步增加日志步骤。
  • trainer: 训练器设置。
    • benchmark: True: 是否启用性能测试。
    • accumulate_grad_batches: 2: 梯度累积的步骤数量,用于处理较大的批次大小。
  • 21
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值