This is my second pass at learning diffusion models. Since the datasets for the fine-tuning and guidance exercises could not be downloaded, this time I covered two parts: first, using DDIMScheduler to speed up sampling, to study the sampling process; second, building a class-conditioned diffusion model on FashionMNIST, to study the training process. The Stable Diffusion Web UI is still being set up, so the advanced task of upscaling the generated images with an upscaler (super-resolution) module is not finished yet.
The tasks still feel on the hard side: many of the models and datasets are blocked and hard to fetch from inside Ubuntu, and the Stable Diffusion WebUI setup is blocked as well, which is frustrating.
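Since that upscaling step has not been run yet, the following is only a rough sketch of how it might be done directly with diffusers instead of the WebUI; the checkpoint name, the prompt, and the input file here are all assumptions, not something that was actually executed:
# Hypothetical sketch (not run): 4x super-resolution of a generated image with diffusers.
import torch
from PIL import Image
from diffusers import StableDiffusionUpscalePipeline  # assumes these weights can be downloaded

sr_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
upscaler = StableDiffusionUpscalePipeline.from_pretrained('stabilityai/stable-diffusion-x4-upscaler').to(sr_device)

low_res = Image.open('generated_face.png').resize((128, 128))   # hypothetical file: a face sampled from the DDPM pipeline below
high_res = upscaler(prompt='a photo of a face', image=low_res).images[0]
high_res.save('generated_face_x4.png')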
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from diffusers import DDIMScheduler, DDPMPipeline
from datasets import load_dataset
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
image_pipe = DDPMPipeline.from_pretrained('models/ddpm-celebahq-256')
image_pipe.to(device)
Loading pipeline components...: 0%| | 0/2 [00:00<?, ?it/s]
DDPMPipeline {
"_class_name": "DDPMPipeline",
"_diffusers_version": "0.21.2",
"_name_or_path": "models/ddpm-celebahq-256",
"scheduler": [
"diffusers",
"DDPMScheduler"
],
"unet": [
"diffusers",
"UNet2DModel"
]
}
images = image_pipe().images
images[0]
0%| | 0/1000 [00:00<?, ?it/s]

# Faster sampling with DDIM
scheduler = DDIMScheduler.from_pretrained('models/ddpm-celebahq-256')
scheduler.set_timesteps(num_inference_steps=40)
scheduler.timesteps
# from diffusers import DDPMScheduler
# scheduler = DDPMScheduler(num_train_timesteps=1000)
# # scheduler.timesteps
tensor([975, 950, 925, 900, 875, 850, 825, 800, 775, 750, 725, 700, 675, 650,
625, 600, 575, 550, 525, 500, 475, 450, 425, 400, 375, 350, 325, 300,
275, 250, 225, 200, 175, 150, 125, 100, 75, 50, 25, 0])
x = torch.randn(4, 3, 256, 256).to(device)  # start from pure noise
for i, t in tqdm(enumerate(scheduler.timesteps)):
    model_input = scheduler.scale_model_input(x, t)
    with torch.no_grad():
        noise_pred = image_pipe.unet(model_input, t)['sample']
    scheduler_output = scheduler.step(noise_pred, t, x)
    x = scheduler_output.prev_sample  # the slightly less noisy sample used for the next step
    if i % 10 == 0 or i == len(scheduler.timesteps) - 1:
        fig, axs = plt.subplots(1, 2, figsize=(12, 5))
        grid = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0)
        # clip to [-1, 1] then shift/scale towards [0, 1] for display; imshow clips anything left outside that range
        axs[0].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.3)
        axs[0].set_title(f'current x (step {i})')
        pred_x0 = scheduler_output.pred_original_sample  # the scheduler's current estimate of the final clean image
        grid = torchvision.utils.make_grid(pred_x0, nrow=4).permute(1, 2, 0)
        axs[1].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.3)
        axs[1].set_title(f"predicted denoised images (step {i})")
        plt.show()
0it [00:00, ?it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
(The warning above is repeated for both panels of each of the five plots.)
[Figures: "current x" and "predicted denoised images" grids at steps 0, 10, 20, 30, 39]
image_pipe.scheduler = scheduler
images = image_pipe(num_inference_steps=40).images
images[0]
0%| | 0/40 [00:00<?, ?it/s]

# Fine-tuning (skipped)
### The butterflies dataset could not be downloaded, so the class-conditioned model below is trained from scratch on MNIST / FashionMNIST instead; a rough sketch of the skipped fine-tuning loop follows.
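For reference only, this is an untested sketch of what the skipped fine-tuning loop might look like using the CelebA-HQ pipeline loaded earlier; some_image_dataloader is a placeholder for whatever image dataset would replace the butterflies, and the learning rate is just an assumed value:
# Hypothetical fine-tuning sketch (not run): keep training image_pipe.unet on a new image dataset.
from diffusers import DDPMScheduler

ft_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
ft_opt = torch.optim.AdamW(image_pipe.unet.parameters(), lr=1e-5)  # assumed learning rate

for clean_images in some_image_dataloader:               # placeholder: batches of 256x256 images scaled to [-1, 1]
    clean_images = clean_images.to(device)
    noise = torch.randn_like(clean_images)
    timesteps = torch.randint(0, ft_scheduler.config.num_train_timesteps,
                              (clean_images.shape[0],), device=device).long()
    noisy_images = ft_scheduler.add_noise(clean_images, noise, timesteps)
    noise_pred = image_pipe.unet(noisy_images, timesteps).sample  # predict the added noise
    loss = F.mse_loss(noise_pred, noise)
    loss.backward()
    ft_opt.step()
    ft_opt.zero_grad()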
from torch.utils.data import DataLoader
dataset = torchvision.datasets.MNIST(root='./data', download=True, train=True, transform=torchvision.transforms.ToTensor())
# dataset = torchvision.datasets.FashionMNIST(root='./data', download=True, train=True, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
print(x.shape, len(train_dataloader))
print(y)
plt.imshow(torchvision.utils.make_grid(x*2-1)[0], cmap='Greys')
torch.Size([8, 1, 28, 28]) 7500
tensor([3, 8, 3, 9, 8, 4, 5, 7])
<matplotlib.image.AxesImage at 0x7f7622225d20>

from torch import nn
from diffusers import UNet2DModel

class ClassConditionedUnet(nn.Module):
    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()
        # map each class label to a small embedding vector
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        # the UNet takes the image plus the class embedding as extra input channels
        self.model = UNet2DModel(
            sample_size=28,
            in_channels=1 + class_emb_size,
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(32, 64, 64),
            down_block_types=(
                'DownBlock2D',
                'AttnDownBlock2D',
                'AttnDownBlock2D',
            ),
            up_block_types=(
                'AttnUpBlock2D',
                'AttnUpBlock2D',
                'UpBlock2D',
            ),
        )

    def forward(self, x, t, class_labels):
        bs, ch, w, h = x.shape
        class_cond = self.class_emb(class_labels)  # (bs, class_emb_size)
        # broadcast the embedding over the spatial dimensions: (bs, class_emb_size, w, h)
        class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
        # concatenate along the channel dimension and let the UNet predict the noise
        net_input = torch.cat([x, class_cond], 1)
        return self.model(net_input, t).sample
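A quick shape check (an illustrative snippet, not part of the original notebook): the 4-dimensional class embedding is broadcast into extra input channels, and the UNet still returns a 1-channel 28x28 noise prediction.
# Illustrative shape check for the conditional UNet defined above (not in the original run).
test_net = ClassConditionedUnet()
test_x = torch.randn(2, 1, 28, 28)             # dummy batch of 2 "noisy" images
test_t = torch.tensor([10, 500])               # one timestep per sample
test_y = torch.tensor([3, 7])                  # class labels
print(test_net(test_x, test_t, test_y).shape)  # expected: torch.Size([2, 1, 28, 28])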
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

n_epoch = 10
net = ClassConditionedUnet().to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []

for epoch in range(n_epoch):
    for x, y in tqdm(train_dataloader):
        x = x.to(device) * 2 - 1  # map images from [0, 1] to [-1, 1]
        y = y.to(device)
        noise = torch.randn_like(x)
        # pick a random training timestep for each sample
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (x.shape[0],)).long().to(device)
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
        pred = net(noisy_x, timesteps, y)  # predict the added noise, conditioned on the class label
        loss = loss_fn(pred, noise)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    avg_loss = sum(losses[-100:]) / 100
    print(f"finished epoch {epoch}, average loss of the last 100 :{avg_loss:05f}")
plt.plot(losses)
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 0, average loss of the last 100 :0.053991
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 1, average loss of the last 100 :0.047291
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 2, average loss of the last 100 :0.044225
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 3, average loss of the last 100 :0.042341
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 4, average loss of the last 100 :0.041912
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 5, average loss of the last 100 :0.040772
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 6, average loss of the last 100 :0.040425
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 7, average loss of the last 100 :0.039488
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 8, average loss of the last 100 :0.040109
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 9, average loss of the last 100 :0.039061
[<matplotlib.lines.Line2D at 0x7f748d250d90>]

# Sample 8 examples of each of the 10 digit classes
x = torch.randn(80, 1, 28, 28).to(device)
y = torch.tensor([[i] * 8 for i in range(10)]).flatten().to(device)

for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
    with torch.no_grad():
        residual = net(x, t, y)  # predict the noise, conditioned on the class label
    x = noise_scheduler.step(residual, t, x).prev_sample

fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')
0it [00:00, ?it/s]
<matplotlib.image.AxesImage at 0x7f748d2067a0>

fashion_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
dd = DataLoader(fashion_dataset, batch_size=8, shuffle=True)
x, y = next(iter(dd))
print(x.shape, y.shape)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
torch.Size([8, 1, 28, 28]) torch.Size([8])
<matplotlib.image.AxesImage at 0x7f748d21d960>

Here FashionMNIST is trained with the same settings used for the MNIST training above; the resulting loss curve and sampled images are shown below:


I then tried adjusting one parameter, lowering the learning rate to 1e-4; the training results are below. The loss oscillates, but the final generated samples still look reasonable. The two runs feel like two different styles: with the smaller learning rate the clothes come out darker.
fashion_dataloader = DataLoader(fashion_dataset, batch_size=128, shuffle=True)
f_losses = []
f_opt = torch.optim.Adam(net.parameters(), lr=1e-4)  # keep training the MNIST-trained net with a smaller lr

for epoch in range(n_epoch):
    for fx, fy in tqdm(fashion_dataloader):
        fx = fx.to(device) * 2 - 1
        fy = fy.to(device)
        f_noise = torch.randn_like(fx)
        f_timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (fx.shape[0],)).long().to(device)
        noisy_fx = noise_scheduler.add_noise(fx, f_noise, f_timesteps)
        f_noise_pred = net(noisy_fx, f_timesteps, fy)
        f_loss = loss_fn(f_noise_pred, f_noise)
        f_opt.zero_grad()
        f_loss.backward()
        f_opt.step()
        f_losses.append(f_loss.item())
    avg_f_loss = sum(f_losses[-100:]) / 100
    print(f"finished epoch {epoch}, average loss of the last 100 :{avg_f_loss:05f}")
plt.plot(f_losses)
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 0, average loss of the last 100 :0.064354
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 1, average loss of the last 100 :0.063139
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 2, average loss of the last 100 :0.062826
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 3, average loss of the last 100 :0.063266
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 4, average loss of the last 100 :0.063906
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 5, average loss of the last 100 :0.064618
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 6, average loss of the last 100 :0.064006
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 7, average loss of the last 100 :0.063150
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 8, average loss of the last 100 :0.063838
0%| | 0/469 [00:00<?, ?it/s]
finished epoch 9, average loss of the last 100 :0.063957
[<matplotlib.lines.Line2D at 0x7f748c8d13f0>]

fx = torch.randn(80, 1, 28, 28).to(device)
fy = torch.tensor([[i] * 8 for i in range(10)]).flatten().to(device)

for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
    with torch.no_grad():
        residual = net(fx, t, fy)
    fx = noise_scheduler.step(residual, t, fx).prev_sample

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(torchvision.utils.make_grid(fx.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')  # 8 images per row (keyword is nrow, not nrows)
0it [00:00, ?it/s]
<matplotlib.image.AxesImage at 0x7f748c8d2d70>
