Using DDPM (diffusion models) to train on your own dataset for dataset augmentation in PyTorch

This article introduces the principles of the DDPM diffusion model, including the forward and reverse processes and how the reverse process can be modeled with variational inference. The author provides a Python training script, `train_unconditional.py`, which uses HuggingFace's Diffusers library to train a model on a citrus leaf disease dataset. Training involves data preprocessing, model definition, optimizer setup, and loss computation. A generation script, `generate.py`, is also provided to synthesize images from the trained model.

1. Introduction

A DDPM consists of two processes: a forward process and a reverse process, where the forward process is also called the diffusion process. Both are parameterized Markov chains. The reverse process is the one used to generate data, and it can be modeled and solved with variational inference. DDPM corrupts the training data by repeatedly adding Gaussian noise and then learns to recover the data by reversing this noising process. At test time, randomly sampled noise is passed through the learned denoising process to generate new data.
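To make the forward process concrete, here is a minimal sketch (not from the original article) of the closed-form noising step x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε with a linear β schedule; this is the same quantity that `DDPMScheduler.add_noise` in Diffusers computes for the training script below:

import torch

# Sketch of the DDPM forward (diffusion) process, assuming a linear beta schedule.
num_train_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_train_timesteps)      # noise schedule beta_t
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)           # cumulative product alpha_bar_t

def add_noise(x0, noise, t):
    # Sample x_t from q(x_t | x_0) in closed form:
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    sqrt_alpha_bar = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1)
    sqrt_one_minus = (1.0 - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1)
    return sqrt_alpha_bar * x0 + sqrt_one_minus * noise

x0 = torch.randn(4, 3, 128, 128)                              # a batch of "clean" images
t = torch.randint(0, num_train_timesteps, (4,))               # a random timestep per sample
x_t = add_noise(x0, torch.randn_like(x0), t)                  # noisy images the model learns to denoise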

2. Hands-on practice

First, install the required Python package:

pip install diffusers
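The training script below also imports accelerate, datasets, torchvision, and tqdm, so depending on your environment you may additionally need (an assumption based on the script's imports, not a command from the original article):

pip install accelerate datasets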

The dataset I used is a citrus leaf disease dataset (not open-sourced) with three classes: Huanglongbing (citrus greening), magnesium deficiency, and healthy leaves.

The dataset layout I use is shown below (there is no need to split it into train, val, and test); a quick check that such a folder loads correctly follows the tree.

dataset_orgin
--Huanglong_disease
----0.jpg
----1.jpg
----...
--Magnesium_deficiency
----0.jpg
----1.jpg
----...
--Normal
----0.jpg
----1.jpg
----...
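Before training, it can be worth confirming that HuggingFace Datasets can read one of these class folders. A minimal sketch, assuming the same local path used in the script's defaults (adjust to your own), using the `imagefolder` builder that the Diffusers example relies on:

from datasets import load_dataset

# Load one class folder as an unlabeled image dataset; the path is illustrative.
dataset = load_dataset(
    "imagefolder",
    data_dir=r"D:\pyCharmdata\datasets_orgin\Huanglong_disease",
    split="train",
)
print(dataset)               # dataset size and features
print(dataset[0]["image"])   # a PIL.Image for the first example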
Create a new Python file for training:

train_unconditional.py
import argparse
import inspect
import math
import os
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F

from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, whoami
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.10.0.dev0")


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')   # device selection

logger = get_logger(__name__)


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    if not isinstance(arr, torch.Tensor):
        arr = torch.from_numpy(arr)
    res = arr[timesteps].float().to(timesteps.device)
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
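# Illustrative note (not part of the original script excerpt): this helper is used later
# in the training loop to broadcast per-timestep quantities over an image batch, e.g.
#   w = _extract_into_tensor(noise_scheduler.alphas_cumprod, timesteps, (bsz, 1, 1, 1))
# so that a 1-D schedule array can be multiplied element-wise with (bsz, C, H, W) tensors.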


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=r'D:\pyCharmdata\datasets_orgin\Huanglong_disease',  # original default: None
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that HF Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=r'D:\pyCharmdata\datasets_orgin\Huanglong_disease',  # original default: None
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="ddpm-model-128",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--overwrite_output_dir", action="store_true")
    # ... (the remaining arguments, dataset preprocessing, model/optimizer setup, and the training loop continue in the full script; the excerpt ends here)
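The remainder of the script builds the UNet, the noise scheduler, the optimizer, and the training loop. For orientation only, a minimal self-contained sketch of the core DDPM training step as implemented in the Diffusers example this script follows (the hyperparameters here are placeholders, not the author's settings):

import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler, UNet2DModel

# Sketch of one epsilon-prediction training step (illustrative values only).
model = UNet2DModel(sample_size=128, in_channels=3, out_channels=3)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

clean_images = torch.randn(2, 3, 128, 128)        # stand-in for a real preprocessed batch
noise = torch.randn_like(clean_images)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (clean_images.shape[0],))

noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)  # forward (diffusion) process
pred = model(noisy_images, timesteps).sample                              # UNet predicts the added noise
loss = F.mse_loss(pred, noise)                                            # simple MSE training objective
loss.backward()
optimizer.step()
optimizer.zero_grad()

The article also mentions a `generate.py` script for sampling synthetic images from the trained model, which is not included in this excerpt. A minimal sketch of what such a script can look like with Diffusers, assuming the model was saved to the `ddpm-model-128` output directory used above (the output folder and batch size are placeholders):

import os
import torch
from diffusers import DDPMPipeline

# Load the trained pipeline from the training output directory.
pipeline = DDPMPipeline.from_pretrained("ddpm-model-128")
pipeline.to("cuda" if torch.cuda.is_available() else "cpu")

# Sample a batch of synthetic images and save them to disk.
os.makedirs("generated", exist_ok=True)
images = pipeline(batch_size=4, num_inference_steps=1000).images  # list of PIL images
for i, image in enumerate(images):
    image.save(f"generated/synthetic_{i}.jpg")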