Diffusion Models (Part 2)

This is my second study session on diffusion models. Since the datasets for the fine-tuning and guidance exercises failed to download, this post covers two parts: first, using DDIMScheduler to speed up sampling, walking through the sampling loop; second, training a class-conditioned diffusion model on FashionMNIST, walking through the training loop. The Stable Diffusion Web UI is still being set up, so the advanced task of upscaling generated images with an upscaler / super-resolution module is not done yet.
The tasks turned out harder than expected: many models and datasets are blocked by the firewall and hard to fetch from inside Ubuntu, and the downloads needed to set up the Stable Diffusion Web UI are blocked as well. Frustrating.

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from diffusers import DDIMScheduler, DDPMPipeline
from datasets import load_dataset
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
image_pipe = DDPMPipeline.from_pretrained('models/ddpm-celebahq-256')
image_pipe.to(device)
DDPMPipeline {
  "_class_name": "DDPMPipeline",
  "_diffusers_version": "0.21.2",
  "_name_or_path": "models/ddpm-celebahq-256",
  "scheduler": [
    "diffusers",
    "DDPMScheduler"
  ],
  "unet": [
    "diffusers",
    "UNet2DModel"
  ]
}
images = image_pipe().images
images[0]
[Image: a face generated by the default DDPM pipeline, which runs the full 1000 denoising steps]

# Faster sampling with DDIM
scheduler = DDIMScheduler.from_pretrained('models/ddpm-celebahq-256')
scheduler.set_timesteps(num_inference_steps=40)
scheduler.timesteps
# For comparison, the default DDPM scheduler would use all 1000 steps:
# from diffusers import DDPMScheduler
# scheduler = DDPMScheduler(num_train_timesteps=1000)
# scheduler.timesteps
tensor([975, 950, 925, 900, 875, 850, 825, 800, 775, 750, 725, 700, 675, 650,
        625, 600, 575, 550, 525, 500, 475, 450, 425, 400, 375, 350, 325, 300,
        275, 250, 225, 200, 175, 150, 125, 100,  75,  50,  25,   0])
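
As the output shows, set_timesteps(num_inference_steps=40) selects 40 evenly spaced timesteps out of the 1000 training steps (stride 1000 // 40 = 25, counting down from 975 to 0). A quick sanity check, reusing the scheduler above:

# The DDIM timesteps are a descending, evenly spaced subset of 0..999
expected = torch.arange(975, -1, -25)
assert torch.equal(scheduler.timesteps.cpu(), expected)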
x = torch.randn(4, 3, 256, 256).to(device)
for i, t in tqdm(enumerate(scheduler.timesteps)):
    model_input = scheduler.scale_model_input(x, t)
    with torch.no_grad():
        noise_pred = image_pipe.unet(model_input, t)['sample']
        scheduler_output = scheduler.step(noise_pred, t, x)
        x = scheduler_output.prev_sample  # the slightly less noisy sample for the previous timestep

        if i % 10 == 0 or i == len(scheduler.timesteps) - 1:
            fig, axs = plt.subplots(1, 2, figsize=(12, 5))
            grid = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0)
            # samples live in [-1, 1]; clip and rescale to [0, 1] for imshow
            axs[0].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)
            axs[0].set_title(f'current x (step {i})')

            pred_x0 = scheduler_output.pred_original_sample  # the scheduler's current estimate of the clean image
            grid = torchvision.utils.make_grid(pred_x0, nrow=4).permute(1, 2, 0)
            axs[1].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)
            axs[1].set_title(f"predicted denoised images (step {i})")
            plt.show()
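
What scheduler.step does for the deterministic DDIM case (eta = 0) is a standard two-part update: recover an estimate of the clean image from the noise prediction, then jump to the previous timestep. A minimal sketch of the math using the scheduler's alphas_cumprod; t_prev here is a stand-in for the next entry of scheduler.timesteps, and this illustrates the formulas rather than reproducing the library's exact implementation:

# Deterministic DDIM update (eta = 0) from timestep t to t_prev:
a_t = scheduler.alphas_cumprod[t]
a_prev = scheduler.alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0)
pred_x0 = (x - (1 - a_t).sqrt() * noise_pred) / a_t.sqrt()            # estimate of x_0
x_prev = a_prev.sqrt() * pred_x0 + (1 - a_prev).sqrt() * noise_pred   # sample at t_prev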
[Images: intermediate results at steps 0, 10, 20, 30 and the final step; left panels show the current x, right panels the predicted denoised images]

image_pipe.scheduler = scheduler
images = image_pipe(num_inference_steps=40).images
images[0]

[Image: a face generated by the same pipeline in only 40 DDIM steps]

# Fine-tuning: the butterflies dataset could not be downloaded, so MNIST is used instead
from torch.utils.data import DataLoader
dataset = torchvision.datasets.MNIST(root='./data', download=True, train=True, transform=torchvision.transforms.ToTensor())
# dataset = torchvision.datasets.FashionMNIST(root='./data', download=True, train=True, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
print(x.shape, len(train_dataloader))
print(y)
plt.imshow(torchvision.utils.make_grid(x*2-1)[0], cmap='Greys')
torch.Size([8, 1, 28, 28]) 7500
tensor([3, 8, 3, 9, 8, 4, 5, 7])

[Image: a grid of eight MNIST digits]

from torch import nn
from diffusers import UNet2DModel

class ClassConditionedUnet(nn.Module):
    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()
        # a learnable embedding vector for each class label
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        # the embedding enters the UNet as extra input channels
        self.model = UNet2DModel(sample_size=28, in_channels=1 + class_emb_size, out_channels=1,
                                 layers_per_block=2,
                                 block_out_channels=(32, 64, 64),
                                 down_block_types=(
                                     'DownBlock2D',
                                     'AttnDownBlock2D',
                                     'AttnDownBlock2D',
                                 ),
                                 up_block_types=(
                                     'AttnUpBlock2D',
                                     'AttnUpBlock2D',
                                     'UpBlock2D',
                                 ))

    def forward(self, x, t, class_labels):
        bs, ch, w, h = x.shape
        class_cond = self.class_emb(class_labels)  # (bs, class_emb_size)
        # broadcast the embedding over the spatial dims: (bs, class_emb_size, w, h)
        class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
        net_input = torch.cat([x, class_cond], 1)  # concatenate along the channel dim
        return self.model(net_input, t).sample
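
A quick smoke test of the shapes (a hypothetical check added for illustration; the tmp_* names are not part of the original run):

# Conditioning only adds input channels; the output is still (bs, 1, 28, 28)
tmp_net = ClassConditionedUnet()
tmp_x = torch.randn(8, 1, 28, 28)
tmp_t = torch.randint(0, 1000, (8,))
tmp_y = torch.randint(0, 10, (8,))
print(tmp_net(tmp_x, tmp_t, tmp_y).shape)  # torch.Size([8, 1, 28, 28])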
from diffusers import DDPMScheduler
# squaredcos_cap_v2 is the cosine noise schedule from the Improved DDPM paper
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
n_epoch = 10
net = ClassConditionedUnet().to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []

for epoch in range(n_epoch):
    for x, y in tqdm(train_dataloader):
        x = x.to(device) * 2 - 1  # map images from [0, 1] to [-1, 1]
        y = y.to(device)
        noise = torch.randn_like(x)  # Gaussian noise with the same shape as the batch
        # the high end of randint is exclusive, so 1000 covers timesteps 0..999
        timesteps = torch.randint(0, 1000, (x.shape[0],)).long().to(device)
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

        pred = net(noisy_x, timesteps, y)  # predict the noise, conditioned on the label
        loss = loss_fn(pred, noise)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

    avg_loss = sum(losses[-100:]) / 100
    print(f"finished epoch {epoch}, average loss of the last 100 :{avg_loss:05f}")
plt.plot(losses)
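
For intuition, the add_noise call above applies the closed-form DDPM forward process in a single shot. A minimal sketch of what it computes (up to the dtype and broadcasting details the library handles; noisy_x_manual is a name introduced here):

# noisy_x = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise, per sample in the batch
alpha_bar = noise_scheduler.alphas_cumprod.to(device)[timesteps].view(-1, 1, 1, 1)
noisy_x_manual = alpha_bar.sqrt() * x + (1 - alpha_bar).sqrt() * noise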
finished epoch 0, average loss of the last 100 :0.053991
finished epoch 1, average loss of the last 100 :0.047291
finished epoch 2, average loss of the last 100 :0.044225
finished epoch 3, average loss of the last 100 :0.042341
finished epoch 4, average loss of the last 100 :0.041912
finished epoch 5, average loss of the last 100 :0.040772
finished epoch 6, average loss of the last 100 :0.040425
finished epoch 7, average loss of the last 100 :0.039488
finished epoch 8, average loss of the last 100 :0.040109
finished epoch 9, average loss of the last 100 :0.039061

[Image: MNIST training loss curve]

x = torch.randn(80, 1, 28, 28).to(device)  # 80 samples: 8 per class
y = torch.tensor([[i] * 8 for i in range(10)]).flatten().to(device)  # labels 0..9, eight of each

for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
    with torch.no_grad():
        residual = net(x, t, y)  # class-conditional noise prediction

    x = noise_scheduler.step(residual, t, x).prev_sample

fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')
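
As an aside, the DDIM trick from the first half of this post should apply here as well. A hypothetical sketch (fast_sched, xf and eps are names introduced for this example) that samples the conditional model in 40 steps instead of 1000:

# Swap in a DDIM scheduler matching the training configuration
fast_sched = DDIMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
fast_sched.set_timesteps(40)
xf = torch.randn(80, 1, 28, 28).to(device)
for t in fast_sched.timesteps:
    with torch.no_grad():
        eps = net(xf, t, y)  # class-conditional noise prediction
    xf = fast_sched.step(eps, t, xf).prev_sample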
[Image: generated digits, eight samples per class, one class per row]

fashion_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
dd = DataLoader(fashion_dataset, batch_size=8, shuffle=True)
x, y = next(iter(dd))
print(x.shape, y.shape)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
torch.Size([8, 1, 28, 28]) torch.Size([8])

[Image: a grid of eight FashionMNIST samples]

Here FashionMNIST is trained with the same parameters used for the MNIST training above; the resulting loss curve and samples are shown below:

[Image: FashionMNIST training loss curve (original settings)]
[Image: FashionMNIST samples (original settings)]

I then tried adjusting the setup, lowering the learning rate to 1e-4; the results follow. The loss oscillates, but the final samples still look reasonable. The two runs feel like two different styles: with the smaller learning rate the clothing comes out darker.

fashion_dataloader = DataLoader(fashion_dataset, batch_size=128, shuffle=True)
f_losses = []
f_opt = torch.optim.Adam(net.parameters(), lr=1e-4)  # continue training the MNIST net at a lower lr
for epoch in range(n_epoch):
    for fx, fy in tqdm(fashion_dataloader):
        fx = fx.to(device) * 2 - 1
        fy = fy.to(device)
        f_noise = torch.randn_like(fx)
        f_timesteps = torch.randint(0, 1000, (fx.shape[0],)).long().to(device)
        noise_fx = noise_scheduler.add_noise(fx, f_noise, f_timesteps)
        f_noise_pred = net(noise_fx, f_timesteps, fy)
        f_loss = loss_fn(f_noise_pred, f_noise)

        f_opt.zero_grad()
        f_loss.backward()
        f_opt.step()
        f_losses.append(f_loss.item())
    avg_f_loss = sum(f_losses[-100:]) / 100
    print(f"finished epoch {epoch}, average loss of the last 100 :{avg_f_loss:05f}")
plt.plot(f_losses)
finished epoch 0, average loss of the last 100 :0.064354
finished epoch 1, average loss of the last 100 :0.063139
finished epoch 2, average loss of the last 100 :0.062826
finished epoch 3, average loss of the last 100 :0.063266
finished epoch 4, average loss of the last 100 :0.063906
finished epoch 5, average loss of the last 100 :0.064618
finished epoch 6, average loss of the last 100 :0.064006
finished epoch 7, average loss of the last 100 :0.063150
finished epoch 8, average loss of the last 100 :0.063838
finished epoch 9, average loss of the last 100 :0.063957

[Image: FashionMNIST fine-tuning loss curve (lr=1e-4)]

fx = torch.randn(80, 1, 28, 28).to(device)
fy = torch.tensor([[i] * 8 for i in range(10)]).flatten().to(device)
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
    with torch.no_grad():
        residual = net(fx, t, fy)
    fx = noise_scheduler.step(residual, t, fx).prev_sample
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
# note: the make_grid keyword is nrow, not nrows
ax.imshow(torchvision.utils.make_grid(fx.detach().cpu().clip(-1, 1), nrow=8)[0], cmap="Greys")
[Image: generated FashionMNIST items, eight samples per class, one class per row]
