第四章 Diffusers 实战
安装 Diffusers 库
pip install -qq -U diffusers datasets transformers accelerate ftfy pyarrow
扩散模型调度器
# Create a DDPM noise scheduler with 1000 training timesteps
# (beta schedule left at the library default).
from diffusers import DDPMScheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
定义扩散模型
from diffusers import UNet2DModel


def model(sample_size=240, in_channels=4, out_channels=4):
    """Build the UNet2D denoising backbone used by the DDPM training loop.

    Args:
        sample_size: Spatial resolution (height = width) of the inputs.
            Defaults to 240 — presumably the BraTS 2D slice size used in
            this chapter (see the brats dataset import in the training
            script); confirm against your data.
        in_channels: Number of input channels (default 4).
        out_channels: Number of predicted channels (default 4, matching
            the input so the net can predict per-channel noise).

    Returns:
        A ``diffusers.UNet2DModel`` instance.
    """
    model = UNet2DModel(
        sample_size=sample_size,
        in_channels=in_channels,
        out_channels=out_channels,
        layers_per_block=2,  # two ResNet layers per UNet block
        block_out_channels=(64, 128, 128, 256),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            # self-attention only at the lower resolutions, where it is cheap
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
    return model
创建扩散模型训练循环
import torch.utils.data.dataset
import torchvision
from dataset import dataset_brats_2D
from torchvision import transforms
from diffusers import DDPMScheduler
import model
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
import time
if __name__ == "__main__":
    device = torch.device('cuda')

    # TODO: plug in your own dataset here (the file imports dataset_brats_2D;
    # adjust the constructor arguments to match its actual signature).
    dataset = dataset_brats_2D()
    # shuffle=True: training batches should be drawn in random order.
    train_dl = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=1)

    model = model.model().to(device)
    model = torch.nn.DataParallel(model, device_ids=[0])

    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

    losses = []
    best_loss = float('inf')  # lowest epoch loss seen so far
    for epoch in range(100):
        for step, batch in enumerate(train_dl):
            clean_images = batch.to(device)
            noise = torch.randn(clean_images.shape).to(device)
            batch_size = clean_images.shape[0]
            # Sample an independent random timestep for every image.
            timesteps = torch.randint(
                0, noise_scheduler.num_train_timesteps, (batch_size,), device=device
            ).long()
            # Forward diffusion: mix the sampled noise into the clean images.
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            # The model is trained to predict the noise that was added.
            noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
            loss = F.mse_loss(noise_pred, noise)
            # BUG FIX: was loss.backward(loss), which scales every gradient by
            # the loss value; plain backward() is correct for a scalar loss.
            loss.backward()
            losses.append(loss.item())
            optimizer.step()
            optimizer.zero_grad()
        if (epoch + 1) % 5 == 0:
            loss_last_epoch = sum(losses[-len(train_dl):]) / len(train_dl)
            print(f"Epoch:{epoch + 1}, loss:{loss_last_epoch}")
            state = {'net': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            # BUG FIX: the original condition was inverted (it only saved when
            # the loss got WORSE, so with the huge initial flag best.pth was
            # never written). Save the checkpoint whenever the loss improves.
            if loss_last_epoch < best_loss:
                torch.save(state, "best.pth")
                best_loss = loss_last_epoch
图像的生成
import time
import torchvision.utils
from diffusers import DDPMPipeline,DDPMScheduler
import cv2
import torch
import torchvision
from PIL import Image
import model
import numpy as np
import time
def show_images(x):
    """Convert model outputs in [-1, 1] to a single PIL image grid.

    The tensor is denormalised back to [0, 1], tiled into one grid,
    clipped to the valid range, scaled to [0, 255] and returned as an
    8-bit PIL image.
    """
    denorm = x * 0.5 + 0.5
    tiled = torchvision.utils.make_grid(denorm)
    hwc = tiled.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    return Image.fromarray(np.array(hwc).astype(np.uint8))
if __name__ == "__main__":
    device = torch.device('cuda')

    model = model.model().to(device)
    # TODO: point this at your own checkpoint file.
    # map_location keeps the load working regardless of which device the
    # checkpoint was saved from.
    ckpt = torch.load(r"", map_location=device)
    # The checkpoint was saved from a DataParallel wrapper, so every key
    # carries a 'module.' prefix; strip it before loading into the bare model.
    model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt['net'].items() if k.startswith('module.')})
    model.eval()  # inference mode

    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
    )
    for k in range(10000):
        # start = time.time()
        # Start each sample from pure Gaussian noise and denoise step by step.
        sample = torch.randn(1, 4, 240, 240).to(device)
        for i, t in enumerate(noise_scheduler.timesteps):
            print(t)
            with torch.no_grad():
                residual = model(sample, t).sample
            # One reverse-diffusion step: x_t -> x_{t-1}.
            sample = noise_scheduler.step(residual, t, sample).prev_sample
        time_flag = time.time()
        print(sample.shape)
        # Save only the first channel of the generated sample, using the
        # current timestamp as a unique filename.
        image = show_images(sample[0][0])
        image.save(str(time_flag) + '_0' + '.png')