《扩散模型 从原理到实战》Hugging Face (三)

第四章 Diffusers 实战

安装 Diffusers 库

pip install -qq -U diffusers datasets transformers accelerate ftfy pyarrow

扩散模型调度器

from diffusers import DDPMScheduler
# Standalone DDPM noise scheduler whose forward (noising) process uses
# 1000 training timesteps.
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000,
)

定义扩散模型

from diffusers import  UNet2DModel

def model(*, sample_size=240, in_channels=4, out_channels=4):
    """Build the 2D UNet noise-prediction network.

    The defaults match the 4-channel 240x240 tensors that the sampling code
    feeds to the network (``torch.randn(1, 4, 240, 240)``). Calling
    ``model()`` with no arguments gives exactly the original configuration.

    Args:
        sample_size: spatial size of the (square) input images.
        in_channels: number of channels of the noisy input.
        out_channels: number of channels of the predicted noise.

    Returns:
        A freshly initialised ``diffusers.UNet2DModel``.
    """
    # Four resolution levels with 2 ResNet layers each. The two
    # lowest-resolution levels (last down blocks / first up blocks) add
    # self-attention.
    net = UNet2DModel(
        sample_size=sample_size,
        in_channels=in_channels,
        out_channels=out_channels,
        layers_per_block=2,
        block_out_channels=(64, 128, 128, 256),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
    return net

创建扩散模型训练循环

import torch.utils.data.dataset
import torchvision
from dataset import dataset_brats_2D
from torchvision import transforms
from diffusers import DDPMScheduler
import model
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
import time

if __name__ == "__main__":
    # Train the UNet as a DDPM noise predictor. The checkpoint with the
    # lowest mean epoch loss is kept in "best.pth".
    device = torch.device('cuda')

    # TODO: replace with your own torch Dataset (e.g. built from
    # dataset_brats_2D). Each batch is assumed to be a float tensor shaped
    # like the sampler's input, (B, 4, 240, 240), scaled to [-1, 1]. Confirm
    # this against your dataset. If it yields tuples, unpack the image before
    # calling `.to(device)`.
    dataset = None
    if dataset is None:
        raise NotImplementedError(
            "Replace `dataset` with your own torch Dataset before training."
        )
    # Shuffle so every epoch visits the batches in a different order.
    train_dl = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=1)

    net = model.model().to(device)
    # DataParallel prefixes state_dict keys with "module."; the sampling
    # script strips that prefix when loading the checkpoint.
    net = torch.nn.DataParallel(net, device_ids=[0])
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
    )
    optimizer = torch.optim.AdamW(net.parameters(), lr=4e-4)
    losses = []  # per-batch loss history over the whole run
    best_loss = float("inf")
    for epoch in range(100):
        epoch_losses = []
        for batch in train_dl:
            clean_images = batch.to(device)
            noise = torch.randn_like(clean_images)
            batch_size = clean_images.shape[0]
            # One uniformly random diffusion timestep per image.
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (batch_size,),
                device=device,
            ).long()
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            noise_pred = net(noisy_images, timesteps, return_dict=False)[0]
            loss = F.mse_loss(noise_pred, noise)
            # Plain backward() on the scalar loss. Passing `loss` as the
            # gradient argument would scale every gradient by the loss value.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_losses.append(loss.item())

        if not epoch_losses:
            # Empty DataLoader: nothing was trained this epoch.
            continue
        losses.extend(epoch_losses)
        loss_last_epoch = sum(epoch_losses) / len(epoch_losses)
        if (epoch + 1) % 5 == 0:
            print(f"Epoch:{epoch + 1}, loss:{loss_last_epoch}")

        # Keep only the best checkpoint, judged on the mean epoch loss
        # (a float, so no autograd graph is retained).
        if loss_last_epoch < best_loss:
            best_loss = loss_last_epoch
            state = {'net': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            torch.save(state, "best.pth")

图像的生成

import time
import torchvision.utils
from diffusers import DDPMPipeline,DDPMScheduler
import cv2
import torch
import torchvision
from PIL import Image
import model
import numpy as np
import time
def show_images(x):
    """Render an image tensor as a single PIL image.

    The input is mapped from [-1, 1] to [0, 1] via ``x * 0.5 + 0.5``; the
    [-1, 1] input range is an assumption, so confirm it against the model
    output. The result is tiled with ``torchvision.utils.make_grid``, moved to
    HWC layout on the CPU, clipped to [0, 1], scaled by 255 and truncated to
    uint8.

    Args:
        x: tensor accepted by ``make_grid``, e.g. a single 2D channel such as
            ``sample[0][0]`` or a batch of images.

    Returns:
        ``PIL.Image.Image`` containing the tiled grid.
    """
    unit_range = x * 0.5 + 0.5
    tiled = torchvision.utils.make_grid(unit_range)
    hwc = tiled.detach().cpu().permute(1, 2, 0)
    pixels = hwc.clip(0, 1) * 255
    as_bytes = np.array(pixels).astype(np.uint8)
    return Image.fromarray(as_bytes)

if __name__ == "__main__":
    # Generate images with a trained UNet by running the full DDPM reverse
    # process from pure Gaussian noise. Channel 0 of each sample is saved as
    # a PNG.
    device = torch.device('cuda')
    # TODO: point this at your own checkpoint. "best.pth" is the filename the
    # training loop writes.
    ckpt_path = "best.pth"
    num_images = 10000

    net = model.model().to(device)
    ckpt = torch.load(ckpt_path, map_location=device)
    # Checkpoints saved from a DataParallel wrapper prefix every key with
    # "module.". Strip only that leading prefix so plain (unwrapped)
    # checkpoints load as well.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt['net'].items()
    }
    net.load_state_dict(state_dict)
    net.eval()

    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
    )
    # Inference only: disable autograd once for the whole sampling run.
    with torch.no_grad():
        for idx in range(num_images):
            sample = torch.randn(1, 4, 240, 240, device=device)
            for t in noise_scheduler.timesteps:
                residual = net(sample, t).sample
                sample = noise_scheduler.step(residual, t, sample).prev_sample
            time_flag = time.time()
            print(f"image {idx + 1}/{num_images}: {tuple(sample.shape)}")
            # Only the first of the 4 channels is written out.
            image = show_images(sample[0][0])
            image.save(str(time_flag) + '_0' + '.png')
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值