Diffusion Models Illustrated, with Code

0. Video walkthrough of the project

A video tutorial is available on Bilibili: https://www.bilibili.com/video/BV1e8411a7mz

1. Diffusion model theory (deriving the loss function)

1.1. Background

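Techniques for image generation, text generation, and multimodal generation have accumulated steadily in AI: generative adversarial networks (GAN), variational autoencoders (VAE), normalizing flow models, autoregressive models (AR), energy-based models, and, in recent years, the hugely popular diffusion models.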

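The success of diffusion models did not come out of nowhere. A similar idea was proposed as early as 2015, and in 2020 the approach we know today was introduced as "Denoising Diffusion Probabilistic Models" (DDPM).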

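NovelAI's recent generation technology is likewise based on diffusion models, and the image below shows how strong its results are. It can be tried out at the linked page.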

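The results this project can achieve are shown below: the output for the input label "sunflower" with cfg=7. As you can see, the quality is already quite good.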

1.2. Training and sampling algorithms

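The figure below gives the overall picture first; sections 1.3 and 1.4 then go through the procedure and the derivation in detail. What we need to do is derive the loss function used in training.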

1.3. Deriving the forward noising formula

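The forward process of a diffusion model gradually adds Gaussian noise to the original image until the image approaches a Gaussian distribution. Because the share of noise in the image keeps growing, the strength of the noise added at each step also increases. As the figure below shows: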

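  • The image at each time step is obtained by adding noise to the image at the previous time step

  • The final image becomes pure noise

  • The strength of the added noise differs at every time step; available schedulers include linear and cosine schedulers, among others

  • This process constructs the labels (targets) used for training, as we will see later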
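
As a rough illustration of the two schedulers mentioned above, here is a minimal pure-Python sketch. The function names and the cosine parameters (`s=0.008`, a cap of 0.999) are assumptions chosen for illustration; the project's own code, shown later, uses a linear schedule from 1e-4 to 0.02 over 500 steps.

```python
import math

def linear_beta_schedule(steps, beta_start=1e-4, beta_end=0.02):
    """Evenly spaced betas from beta_start to beta_end (like a linspace)."""
    if steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (steps - 1)
    return [beta_start + i * step for i in range(steps)]

def cosine_beta_schedule(steps, s=0.008, max_beta=0.999):
    """Betas derived from a squared-cosine curve for the cumulative alpha product."""
    def alpha_bar(u):
        # u in [0, 1] is the fraction of the forward process completed
        return math.cos((u + s) / (1 + s) * math.pi / 2) ** 2

    betas = []
    for i in range(steps):
        ratio = alpha_bar((i + 1) / steps) / alpha_bar(i / steps)
        betas.append(min(1 - ratio, max_beta))  # cap to avoid beta reaching 1
    return betas

linear = linear_beta_schedule(500)
cosine = cosine_beta_schedule(500)
```

Both functions return one noise strength per time step; only the shape of the curve differs.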

The derivation below shows how to obtain the image at time step t directly from the initial image.

This formula lays the groundwork for what follows; the next section is the key derivation of the loss function.
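
The shortcut from the initial image to step t can be checked numerically. The sketch below assumes the linear schedule used later in this project (500 steps, 1e-4 to 0.02) and tracks a single scalar "pixel": x_t is written as signal * x_0 plus Gaussian noise with a given variance. The helper names are hypothetical. Applying the one-step rule repeatedly should give the same coefficients as the closed form sqrt(alpha_hat_t) and 1 - alpha_hat_t.

```python
import math

T = 500
betas = [1e-4 + i * (0.02 - 1e-4) / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]

# Cumulative product of the alphas (called alpha_hat in the project's code).
alpha_hat = []
running = 1.0
for a in alphas:
    running *= a
    alpha_hat.append(running)

def coefficients_step_by_step(t):
    """Apply x_k = sqrt(alpha_k) * x_{k-1} + sqrt(beta_k) * noise for k = 0..t."""
    signal, variance = 1.0, 0.0
    for k in range(t + 1):
        signal *= math.sqrt(alphas[k])
        # Independent Gaussian noises: old variance is scaled, new variance is added.
        variance = alphas[k] * variance + betas[k]
    return signal, variance

def coefficients_closed_form(t):
    """Jump straight to step t using the cumulative product."""
    return math.sqrt(alpha_hat[t]), 1.0 - alpha_hat[t]
```

With this schedule `alpha_hat[-1]` is close to zero, which matches the statement that the final image is almost pure noise.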

1.4. Optimization objective: deriving the loss function

The forward diffusion above is not difficult. Next we derive the reverse diffusion process, that is, going from x_t to x_{t-1}.
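
As a minimal sketch of one reverse step, the code below applies the same update rule that the `sample` method later in this project uses, but on a single scalar "pixel" and with plain Python. The function name and the demo values are assumptions for illustration. As a sanity check, if the noise prediction is perfect and no fresh noise is added, reversing the very first schedule entry recovers the original value exactly.

```python
import math
import random

T = 500
betas = [1e-4 + i * (0.02 - 1e-4) / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]
alpha_hat = []
running = 1.0
for a in alphas:
    running *= a
    alpha_hat.append(running)

def reverse_step(x_t, predicted_noise, t, add_noise=True, rng=random):
    """Remove the predicted noise from x_t, then optionally re-inject fresh noise."""
    mean = (x_t - (1 - alphas[t]) / math.sqrt(1 - alpha_hat[t]) * predicted_noise) / math.sqrt(alphas[t])
    z = rng.gauss(0.0, 1.0) if add_noise else 0.0
    return mean + math.sqrt(betas[t]) * z

# Noise a value with the first schedule entry, then undo it with the true noise.
x0, eps = 0.3, -1.2
x_t = math.sqrt(alpha_hat[0]) * x0 + math.sqrt(1 - alpha_hat[0]) * eps
recovered = reverse_step(x_t, eps, t=0, add_noise=False)
```

In practice the true noise is unknown, which is why a network is trained to predict it and why the loss compares the true and predicted noise.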


2. Unconditional generation (generating random images)

We use the Stanford Cars images as an example, without class labels.

2.1. Training process explained

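We obtain the training targets by sampling from the forward process and train a U-Net, embedding an encoding of the time step in the model input. This is similar to positional encoding in the Transformer and makes the model easier to train. As shown in the figure below: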
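
The project's `modules.py` is not shown here, so the following is only a sketch of how a Transformer-style sinusoidal time-step encoding is commonly computed, not the project's actual implementation. The function name, the channel count, and the base of 10000 are assumptions.

```python
import math

def timestep_embedding(t, channels=8):
    """Encode integer time step t as sines and cosines at geometrically spaced frequencies."""
    half = channels // 2
    inv_freq = [1.0 / (10000 ** (2 * i / channels)) for i in range(half)]
    sines = [math.sin(t * f) for f in inv_freq]
    cosines = [math.cos(t * f) for f in inv_freq]
    return sines + cosines

emb_0 = timestep_embedding(0)
emb_250 = timestep_embedding(250)
```

Each time step maps to a distinct vector with entries in [-1, 1], which the network can add to or concatenate with its feature maps.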

2.2. Extracting the data

Unzip the dataset. This is only needed the first time you run the project!

In [13]

import os
if not os.path.exists("work/cars"):
    !mkdir work/cars
!unzip -oq data/data173302/stanford_cars.zip -d work/cars

In [14]

# Remove files we do not need
!rm -rf work/cars/cars_test
!rm -rf work/cars/devkit
!rm -rf work/cars/car_devkit.tgz
!rm -rf work/cars/cars_train.tgz
!rm -rf work/cars/cars_test.tgz
!rm -rf work/cars/cars_test_annos_withlabels.mat

2.3. Viewing the data

Take a look at our car images.

In [1]

import paddle
import paddle.vision
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline

# Helper that displays a list of images in a grid
def show_images(imgs_paths=(), cols=4):
    num_samples = len(imgs_paths)
    plt.figure(figsize=(15,15))
    i = 0
    for img_path in imgs_paths:
        img = Image.open(img_path)
        plt.subplot(int(num_samples/cols + 1), cols, i + 1)
        plt.imshow(img)
        i += 1

imgs_paths = [
    "work/cars/cars_train/05930.jpg", "work/cars/cars_train/06816.jpg", "work/cars/cars_train/02885.jpg", "work/cars/cars_train/07471.jpg",
    "work/cars/cars_train/06600.jpg", "work/cars/cars_train/06020.jpg", "work/cars/cars_train/04818.jpg", "work/cars/cars_train/06088.jpg"
]
show_images(imgs_paths)

<Figure size 1500x1500 with 8 Axes>

2.4. Building the dataset

We can simply use the dataset interface in paddle.vision.

In [2]

import os
import paddle
import paddle.nn as nn
import paddle.vision as V
from PIL import Image
from matplotlib import pyplot as plt
from paddle.io import DataLoader

# We do not need image labels here, so the dataset interface provided by paddle.vision can be used directly
def get_data(args):
    transforms = V.transforms.Compose([
        V.transforms.Resize(80),  # args.image_size + 1/4 *args.image_size
        V.transforms.RandomResizedCrop(args.image_size, scale=(0.8, 1.0)),
        V.transforms.ToTensor(),
        V.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = V.datasets.ImageFolder(args.dataset_path, transform=transforms)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    return dataloader

2.5. Training procedure

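During training, hyperparameters can be changed by editing the fields of the ARGS class. Once you know that the loss is simply the mean squared error between two image-shaped tensors, the sampled noise and the noise predicted by the model, the code becomes fairly simple. Compared with a GAN, a diffusion model's hyperparameters are easier to tune and the model is easier to train.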

In [3]

"""ddpm"""

import os
import paddle
import paddle.nn as nn
from matplotlib import pyplot as plt
%matplotlib inline
from tqdm import tqdm
from paddle import optimizer
# from utils import *
from modules import UNet    # the U-Net model
import logging
import numpy as np

logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")


class Diffusion:
    def __init__(self, noise_steps=500, beta_start=1e-4, beta_end=0.02, img_size=64, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        self.beta = self.prepare_noise_schedule()
        self.alpha = 1. - self.beta
        self.alpha_hat = paddle.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        return paddle.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        sqrt_alpha_hat = paddle.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = paddle.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = paddle.randn(shape=x.shape)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return paddle.randint(low=1, high=self.noise_steps, shape=(n,))

    def sample(self, model, n):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with paddle.no_grad():
            x = paddle.randn((n, 3, self.img_size, self.img_size))
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):


                t = paddle.to_tensor([i] * x.shape[0]).astype("int64")
                # print(x.shape, t.shape)

                # print(f"finished step {i}")
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = paddle.randn(shape=x.shape)
                else:
                    noise = paddle.zeros_like(x)
                x = 1 / paddle.sqrt(alpha) * (x - ((1 - alpha) / (paddle.sqrt(1 - alpha_hat))) * predicted_noise) + paddle.sqrt(beta) * noise
        model.train()
        x = (x.clip(-1, 1) + 1) / 2
        x = (x * 255)
        return x

def train(args):
    # setup_logging(args.run_name)
    device = args.device
    dataloader = get_data(args)

    image = next(iter(dataloader))[0]

    model = UNet()
    opt = optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    # logger = SummaryWriter(os.path.join("runs", args.run_name))
    l = len(dataloader)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, images in enumerate(pbar):
            # print(images)
            t = diffusion.sample_timesteps(images[0].shape[0])
            x_t, noise = diffusion.noise_images(images[0], t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)  # loss: MSE between true and predicted noise

            opt.clear_grad()
            loss.backward()
            opt.step()

            pbar.set_postfix(MSE=loss.item())

            # print(("MSE", loss.item(), "global_step", epoch * l + i))
            # logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)
        
        if epoch % 20 == 0:
            paddle.save(model.state_dict(), f"car_models/ddpm_uncond{epoch}.pdparams")
            sampled_images = diffusion.sample(model, n=8)

            for i in range(8):
                img = sampled_images[i].transpose([1, 2, 0])
                img = np.array(img).astype("uint8")
                plt.subplot(2,4,i+1)
                plt.imshow(img)
            plt.show()

def launch():
    import argparse

    # Hyperparameter settings
    class ARGS:
        def __init__(self):
            self.run_name = "DDPM_Uncondtional"
            self.epochs = 150
            self.batch_size = 24
            self.image_size = 64
            self.dataset_path = r"/home/aistudio/work/cars"
            self.device = "cuda"
            self.lr = 1.5e-4

    args = ARGS()
    train(args)


if __name__ == '__main__':
    launch()
    pass
W1024 11:03:25.091079   573 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1024 11:03:25.094197   573 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
11:03:25 - INFO: Starting epoch 0:
100%|██████████| 340/340 [02:13<00:00,  3.70it/s, MSE=0.15] 
11:05:39 - INFO: Sampling 8 new images....
499it [00:20, 23.93it/s]

<Figure size 640x480 with 8 Axes>
11:06:00 - INFO: Starting epoch 1:
100%|██████████| 340/340 [02:13<00:00,  3.11it/s, MSE=0.0725]
11:08:14 - INFO: Starting epoch 2:
100%|██████████| 340/340 [02:12<00:00,  3.37it/s, MSE=0.0777]
11:10:26 - INFO: Starting epoch 3:
100%|██████████| 340/340 [02:12<00:00,  3.44it/s, MSE=0.0814]
11:12:38 - INFO: Starting epoch 4:
100%|██████████| 340/340 [02:12<00:00,  3.30it/s, MSE=0.0579]
11:14:51 - INFO: Starting epoch 5:
100%|██████████| 340/340 [02:13<00:00,  3.40it/s, MSE=0.107] 
11:17:05 - INFO: Starting epoch 6:
100%|██████████| 340/340 [02:14<00:00,  3.49it/s, MSE=0.0742]
11:19:19 - INFO: Starting epoch 7:
100%|██████████| 340/340 [02:14<00:00,  3.20it/s, MSE=0.0422]
11:21:34 - INFO: Starting epoch 8:
100%|██████████| 340/340 [02:13<00:00,  3.26it/s, MSE=0.0527]
11:23:47 - INFO: Starting epoch 9:
100%|██████████| 340/340 [02:13<00:00,  3.45it/s, MSE=0.064] 
11:26:01 - INFO: Starting epoch 10:
100%|██████████| 340/340 [02:15<00:00,  2.91it/s, MSE=0.043] 
11:28:17 - INFO: Starting epoch 11:
100%|██████████| 340/340 [02:14<00:00,  2.60it/s, MSE=0.0712]
11:30:31 - INFO: Starting epoch 12:
100%|██████████| 340/340 [02:13<00:00,  3.23it/s, MSE=0.0674]
11:32:44 - INFO: Starting epoch 13:
100%|██████████| 340/340 [02:14<00:00,  3.00it/s, MSE=0.0464]
11:34:59 - INFO: Starting epoch 14:
100%|██████████| 340/340 [02:14<00:00,  2.93it/s, MSE=0.0349]
11:37:13 - INFO: Starting epoch 15:
100%|██████████| 340/340 [02:13<00:00,  3.58it/s, MSE=0.0279]
11:39:26 - INFO: Starting epoch 16:
100%|██████████| 340/340 [02:14<00:00,  2.62it/s, MSE=0.0436]
11:41:40 - INFO: Starting epoch 17:
100%|██████████| 340/340 [02:15<00:00,  3.06it/s, MSE=0.0278]
11:43:55 - INFO: Starting epoch 18:
100%|██████████| 340/340 [02:13<00:00,  3.03it/s, MSE=0.0318]
11:46:09 - INFO: Starting epoch 19:
100%|██████████| 340/340 [02:13<00:00,  3.01it/s, MSE=0.0743]
11:48:22 - INFO: Starting epoch 20:
100%|██████████| 340/340 [02:12<00:00,  3.26it/s, MSE=0.0721]
11:50:36 - INFO: Sampling 8 new images....
499it [00:20, 24.05it/s]

<Figure size 640x480 with 8 Axes>
11:50:57 - INFO: Starting epoch 21:
100%|██████████| 340/340 [02:13<00:00,  3.32it/s, MSE=0.0275]
11:53:10 - INFO: Starting epoch 22:
100%|██████████| 340/340 [02:13<00:00,  3.23it/s, MSE=0.028] 
11:55:24 - INFO: Starting epoch 23:
100%|██████████| 340/340 [02:13<00:00,  2.89it/s, MSE=0.0155]
11:57:37 - INFO: Starting epoch 24:
100%|██████████| 340/340 [02:13<00:00,  3.17it/s, MSE=0.0386]
11:59:51 - INFO: Starting epoch 25:
100%|██████████| 340/340 [02:13<00:00,  3.16it/s, MSE=0.0189]
12:02:04 - INFO: Starting epoch 26:
100%|██████████| 340/340 [02:13<00:00,  3.23it/s, MSE=0.0285]
12:04:18 - INFO: Starting epoch 27:
100%|██████████| 340/340 [02:13<00:00,  3.47it/s, MSE=0.0593]
12:06:31 - INFO: Starting epoch 28:
100%|██████████| 340/340 [02:14<00:00,  2.98it/s, MSE=0.0151]
12:08:45 - INFO: Starting epoch 29:
100%|██████████| 340/340 [02:12<00:00,  3.40it/s, MSE=0.0552]
12:10:57 - INFO: Starting epoch 30:
100%|██████████| 340/340 [02:14<00:00,  3.53it/s, MSE=0.0335]
12:13:12 - INFO: Starting epoch 31:
100%|██████████| 340/340 [02:13<00:00,  3.01it/s, MSE=0.00773]
12:15:25 - INFO: Starting epoch 32:
100%|██████████| 340/340 [02:13<00:00,  3.03it/s, MSE=0.0907]
12:17:39 - INFO: Starting epoch 33:
100%|██████████| 340/340 [02:15<00:00,  3.65it/s, MSE=0.0412]
12:19:54 - INFO: Starting epoch 34:
100%|██████████| 340/340 [02:13<00:00,  3.55it/s, MSE=0.0359]
12:22:08 - INFO: Starting epoch 35:
100%|██████████| 340/340 [02:13<00:00,  3.30it/s, MSE=0.0563]
12:24:21 - INFO: Starting epoch 36:
100%|██████████| 340/340 [02:13<00:00,  3.34it/s, MSE=0.0299]
12:26:35 - INFO: Starting epoch 37:
100%|██████████| 340/340 [02:13<00:00,  3.24it/s, MSE=0.0315] 
12:28:49 - INFO: Starting epoch 38:
100%|██████████| 340/340 [02:13<00:00,  3.08it/s, MSE=0.0455]
12:31:02 - INFO: Starting epoch 39:
100%|██████████| 340/340 [02:12<00:00,  3.23it/s, MSE=0.024] 
12:33:15 - INFO: Starting epoch 40:
100%|██████████| 340/340 [02:13<00:00,  3.32it/s, MSE=0.0416]
12:35:29 - INFO: Sampling 8 new images....
499it [00:20, 23.89it/s]

<Figure size 640x480 with 8 Axes>
12:35:50 - INFO: Starting epoch 41:
100%|██████████| 340/340 [02:13<00:00,  3.18it/s, MSE=0.0134]
12:38:03 - INFO: Starting epoch 42:
100%|██████████| 340/340 [02:12<00:00,  3.77it/s, MSE=0.0948]
12:40:16 - INFO: Starting epoch 43:
100%|██████████| 340/340 [02:13<00:00,  3.16it/s, MSE=0.0208]
12:42:30 - INFO: Starting epoch 44:
100%|██████████| 340/340 [02:13<00:00,  3.29it/s, MSE=0.0421]
12:44:44 - INFO: Starting epoch 45:
100%|██████████| 340/340 [02:13<00:00,  2.88it/s, MSE=0.0296]
12:46:57 - INFO: Starting epoch 46:
100%|██████████| 340/340 [02:12<00:00,  3.00it/s, MSE=0.0398]
12:49:10 - INFO: Starting epoch 47:
100%|██████████| 340/340 [02:13<00:00,  3.06it/s, MSE=0.0269]
12:51:24 - INFO: Starting epoch 48:
100%|██████████| 340/340 [02:12<00:00,  3.34it/s, MSE=0.0635]
12:53:37 - INFO: Starting epoch 49:
100%|██████████| 340/340 [02:12<00:00,  3.58it/s, MSE=0.0687]
12:55:49 - INFO: Starting epoch 50:
100%|██████████| 340/340 [02:12<00:00,  3.08it/s, MSE=0.0253]
12:58:01 - INFO: Starting epoch 51:
100%|██████████| 340/340 [02:12<00:00,  3.33it/s, MSE=0.0219]
01:00:14 - INFO: Starting epoch 52:
100%|██████████| 340/340 [02:12<00:00,  3.13it/s, MSE=0.0422]
01:02:27 - INFO: Starting epoch 53:
100%|██████████| 340/340 [02:12<00:00,  3.26it/s, MSE=0.0187]
01:04:39 - INFO: Starting epoch 54:
100%|██████████| 340/340 [02:14<00:00,  3.39it/s, MSE=0.0453]
01:06:54 - INFO: Starting epoch 55:
100%|██████████| 340/340 [02:14<00:00,  3.45it/s, MSE=0.101] 
01:09:08 - INFO: Starting epoch 56:
100%|██████████| 340/340 [02:15<00:00,  3.22it/s, MSE=0.016] 
01:11:23 - INFO: Starting epoch 57:
100%|██████████| 340/340 [02:14<00:00,  3.21it/s, MSE=0.0173]
01:13:38 - INFO: Starting epoch 58:
100%|██████████| 340/340 [02:13<00:00,  2.65it/s, MSE=0.0127]
01:15:52 - INFO: Starting epoch 59:
100%|██████████| 340/340 [02:14<00:00,  3.56it/s, MSE=0.112] 
01:18:06 - INFO: Starting epoch 60:
100%|██████████| 340/340 [02:14<00:00,  3.01it/s, MSE=0.0155]
01:20:21 - INFO: Sampling 8 new images....
499it [00:21, 23.74it/s]

<Figure size 640x480 with 8 Axes>
01:20:42 - INFO: Starting epoch 61:
100%|██████████| 340/340 [02:15<00:00,  3.17it/s, MSE=0.0143]
01:22:58 - INFO: Starting epoch 62:
100%|██████████| 340/340 [02:15<00:00,  3.26it/s, MSE=0.0731]
01:25:14 - INFO: Starting epoch 63:
100%|██████████| 340/340 [02:14<00:00,  3.38it/s, MSE=0.0484]
01:27:28 - INFO: Starting epoch 64:
100%|██████████| 340/340 [02:16<00:00,  3.30it/s, MSE=0.0154]
01:29:45 - INFO: Starting epoch 65:
100%|██████████| 340/340 [02:15<00:00,  3.31it/s, MSE=0.0224]
01:32:00 - INFO: Starting epoch 66:
100%|██████████| 340/340 [02:15<00:00,  3.14it/s, MSE=0.0265]
01:34:16 - INFO: Starting epoch 67:
100%|██████████| 340/340 [02:14<00:00,  3.10it/s, MSE=0.0326]
01:36:30 - INFO: Starting epoch 68:
100%|██████████| 340/340 [02:14<00:00,  3.35it/s, MSE=0.0656]
01:38:44 - INFO: Starting epoch 69:
100%|██████████| 340/340 [02:14<00:00,  3.20it/s, MSE=0.0591]
01:40:58 - INFO: Starting epoch 70:
100%|██████████| 340/340 [02:13<00:00,  3.34it/s, MSE=0.0196] 
01:43:12 - INFO: Starting epoch 71:
100%|██████████| 340/340 [02:15<00:00,  2.64it/s, MSE=0.021] 
01:45:28 - INFO: Starting epoch 72:
100%|██████████| 340/340 [02:14<00:00,  2.85it/s, MSE=0.0166]
01:47:42 - INFO: Starting epoch 73:
100%|██████████| 340/340 [02:15<00:00,  3.31it/s, MSE=0.0408]
01:49:57 - INFO: Starting epoch 74:
100%|██████████| 340/340 [02:14<00:00,  3.06it/s, MSE=0.0705] 
01:52:12 - INFO: Starting epoch 75:
100%|██████████| 340/340 [02:14<00:00,  3.06it/s, MSE=0.0326]
01:54:26 - INFO: Starting epoch 76:
100%|██████████| 340/340 [02:13<00:00,  3.55it/s, MSE=0.016] 
01:56:39 - INFO: Starting epoch 77:
100%|██████████| 340/340 [02:13<00:00,  2.98it/s, MSE=0.0122]
01:58:53 - INFO: Starting epoch 78:
100%|██████████| 340/340 [02:13<00:00,  3.57it/s, MSE=0.0304]
02:01:06 - INFO: Starting epoch 79:
100%|██████████| 340/340 [02:14<00:00,  3.17it/s, MSE=0.0186]
02:03:21 - INFO: Starting epoch 80:
100%|██████████| 340/340 [02:14<00:00,  3.37it/s, MSE=0.0248]
02:05:35 - INFO: Sampling 8 new images....
499it [00:21, 22.82it/s]

<Figure size 640x480 with 8 Axes>
02:05:57 - INFO: Starting epoch 81:
100%|██████████| 340/340 [02:13<00:00,  2.93it/s, MSE=0.0321]
02:08:11 - INFO: Starting epoch 82:
100%|██████████| 340/340 [02:15<00:00,  2.76it/s, MSE=0.0274]
02:10:26 - INFO: Starting epoch 83:
100%|██████████| 340/340 [02:16<00:00,  3.49it/s, MSE=0.0069]
02:12:42 - INFO: Starting epoch 84:
100%|██████████| 340/340 [02:13<00:00,  3.05it/s, MSE=0.0847]
02:14:56 - INFO: Starting epoch 85:
100%|██████████| 340/340 [02:13<00:00,  3.23it/s, MSE=0.0237]
02:17:09 - INFO: Starting epoch 86:
100%|██████████| 340/340 [02:13<00:00,  2.71it/s, MSE=0.0124]
02:19:23 - INFO: Starting epoch 87:
100%|██████████| 340/340 [02:14<00:00,  3.69it/s, MSE=0.0537]
02:21:37 - INFO: Starting epoch 88:
100%|██████████| 340/340 [02:13<00:00,  3.13it/s, MSE=0.0463]
02:23:51 - INFO: Starting epoch 89:
100%|██████████| 340/340 [02:13<00:00,  2.85it/s, MSE=0.0137]
02:26:04 - INFO: Starting epoch 90:
100%|██████████| 340/340 [02:12<00:00,  3.05it/s, MSE=0.0198]
02:28:17 - INFO: Starting epoch 91:
100%|██████████| 340/340 [02:12<00:00,  3.31it/s, MSE=0.0205] 
02:30:30 - INFO: Starting epoch 92:
100%|██████████| 340/340 [02:12<00:00,  2.79it/s, MSE=0.0146]
02:32:43 - INFO: Starting epoch 93:
100%|██████████| 340/340 [02:12<00:00,  2.94it/s, MSE=0.00888]
02:34:56 - INFO: Starting epoch 94:
100%|██████████| 340/340 [02:12<00:00,  3.20it/s, MSE=0.0572]
02:37:08 - INFO: Starting epoch 95:
100%|██████████| 340/340 [02:13<00:00,  3.11it/s, MSE=0.021] 
02:39:22 - INFO: Starting epoch 96:
100%|██████████| 340/340 [02:13<00:00,  3.24it/s, MSE=0.0392]
02:41:35 - INFO: Starting epoch 97:
100%|██████████| 340/340 [02:12<00:00,  2.66it/s, MSE=0.0166]
02:43:48 - INFO: Starting epoch 98:
100%|██████████| 340/340 [02:14<00:00,  2.51it/s, MSE=0.0591]
02:46:03 - INFO: Starting epoch 99:
100%|██████████| 340/340 [02:16<00:00,  3.14it/s, MSE=0.0283]
02:48:19 - INFO: Starting epoch 100:
100%|██████████| 340/340 [02:13<00:00,  3.19it/s, MSE=0.0276]
02:50:33 - INFO: Sampling 8 new images....
499it [00:21, 23.23it/s]

<Figure size 640x480 with 8 Axes>
02:50:55 - INFO: Starting epoch 101:
100%|██████████| 340/340 [02:14<00:00,  3.48it/s, MSE=0.0293] 
02:53:10 - INFO: Starting epoch 102:
100%|██████████| 340/340 [02:16<00:00,  3.12it/s, MSE=0.0518]
02:55:27 - INFO: Starting epoch 103:
100%|██████████| 340/340 [02:14<00:00,  3.46it/s, MSE=0.0133]
02:57:42 - INFO: Starting epoch 104:
100%|██████████| 340/340 [02:15<00:00,  3.32it/s, MSE=0.0207] 
02:59:58 - INFO: Starting epoch 105:
100%|██████████| 340/340 [02:14<00:00,  3.26it/s, MSE=0.00727]
03:02:12 - INFO: Starting epoch 106:
100%|██████████| 340/340 [02:15<00:00,  3.81it/s, MSE=0.0319]
03:04:28 - INFO: Starting epoch 107:
100%|██████████| 340/340 [02:15<00:00,  3.11it/s, MSE=0.0348]
03:06:44 - INFO: Starting epoch 108:
100%|██████████| 340/340 [02:15<00:00,  3.34it/s, MSE=0.0245]
03:08:59 - INFO: Starting epoch 109:
100%|██████████| 340/340 [02:15<00:00,  3.24it/s, MSE=0.0139]
03:11:14 - INFO: Starting epoch 110:
100%|██████████| 340/340 [02:15<00:00,  3.23it/s, MSE=0.0311]
03:13:29 - INFO: Starting epoch 111:
100%|██████████| 340/340 [02:15<00:00,  3.53it/s, MSE=0.0234]
03:15:45 - INFO: Starting epoch 112:
100%|██████████| 340/340 [02:16<00:00,  3.13it/s, MSE=0.0158]
03:18:01 - INFO: Starting epoch 113:
100%|██████████| 340/340 [02:15<00:00,  3.44it/s, MSE=0.0315]
03:20:17 - INFO: Starting epoch 114:
100%|██████████| 340/340 [02:13<00:00,  3.16it/s, MSE=0.0187]
03:22:30 - INFO: Starting epoch 115:
100%|██████████| 340/340 [02:13<00:00,  3.23it/s, MSE=0.0228]
03:24:43 - INFO: Starting epoch 116:
100%|██████████| 340/340 [02:14<00:00,  3.04it/s, MSE=0.0607]
03:26:57 - INFO: Starting epoch 117:
100%|██████████| 340/340 [02:13<00:00,  3.34it/s, MSE=0.0217]
03:29:10 - INFO: Starting epoch 118:
100%|██████████| 340/340 [02:13<00:00,  3.28it/s, MSE=0.0131]
03:31:24 - INFO: Starting epoch 119:
100%|██████████| 340/340 [02:15<00:00,  3.54it/s, MSE=0.0618]
03:33:39 - INFO: Starting epoch 120:
100%|██████████| 340/340 [02:15<00:00,  3.08it/s, MSE=0.0388]
03:35:55 - INFO: Sampling 8 new images....
499it [00:21, 23.36it/s]

<Figure size 640x480 with 8 Axes>
03:36:16 - INFO: Starting epoch 121:
100%|██████████| 340/340 [02:19<00:00,  3.14it/s, MSE=0.0142] 
03:38:36 - INFO: Starting epoch 122:
100%|██████████| 340/340 [02:19<00:00,  2.97it/s, MSE=0.0112]
03:40:56 - INFO: Starting epoch 123:
100%|██████████| 340/340 [02:19<00:00,  2.84it/s, MSE=0.0243]
03:43:15 - INFO: Starting epoch 124:
100%|██████████| 340/340 [02:19<00:00,  3.11it/s, MSE=0.0312]
03:45:35 - INFO: Starting epoch 125:
100%|██████████| 340/340 [02:19<00:00,  3.26it/s, MSE=0.0513] 
03:47:54 - INFO: Starting epoch 126:
100%|██████████| 340/340 [02:18<00:00,  3.10it/s, MSE=0.0254]
03:50:13 - INFO: Starting epoch 127:
100%|██████████| 340/340 [02:17<00:00,  3.18it/s, MSE=0.00965]
03:52:30 - INFO: Starting epoch 128:
100%|██████████| 340/340 [02:17<00:00,  3.35it/s, MSE=0.0183]
03:54:47 - INFO: Starting epoch 129:
100%|██████████| 340/340 [02:17<00:00,  3.36it/s, MSE=0.0158]
03:57:05 - INFO: Starting epoch 130:
100%|██████████| 340/340 [02:18<00:00,  3.29it/s, MSE=0.0326]
03:59:24 - INFO: Starting epoch 131:
100%|██████████| 340/340 [02:17<00:00,  3.18it/s, MSE=0.0224]
04:01:42 - INFO: Starting epoch 132:
100%|██████████| 340/340 [02:16<00:00,  3.11it/s, MSE=0.0367]
04:03:58 - INFO: Starting epoch 133:
100%|██████████| 340/340 [02:18<00:00,  2.95it/s, MSE=0.0231]
04:06:16 - INFO: Starting epoch 134:
100%|██████████| 340/340 [02:19<00:00,  3.34it/s, MSE=0.0195]
04:08:35 - INFO: Starting epoch 135:
100%|██████████| 340/340 [02:18<00:00,  3.30it/s, MSE=0.00914]
04:10:54 - INFO: Starting epoch 136:
100%|██████████| 340/340 [02:19<00:00,  2.76it/s, MSE=0.0355]
04:13:13 - INFO: Starting epoch 137:
100%|██████████| 340/340 [02:19<00:00,  3.14it/s, MSE=0.0365]
04:15:33 - INFO: Starting epoch 138:
100%|██████████| 340/340 [02:20<00:00,  3.38it/s, MSE=0.0182] 
04:17:53 - INFO: Starting epoch 139:
100%|██████████| 340/340 [02:18<00:00,  3.19it/s, MSE=0.057]  
04:20:11 - INFO: Starting epoch 140:
100%|██████████| 340/340 [02:16<00:00,  3.27it/s, MSE=0.0156] 
04:22:28 - INFO: Sampling 8 new images....
499it [00:21, 22.81it/s]

<Figure size 640x480 with 8 Axes>
04:22:51 - INFO: Starting epoch 141:
100%|██████████| 340/340 [02:17<00:00,  3.11it/s, MSE=0.0256]
04:25:09 - INFO: Starting epoch 142:
100%|██████████| 340/340 [02:16<00:00,  2.82it/s, MSE=0.0271]
04:27:26 - INFO: Starting epoch 143:
100%|██████████| 340/340 [02:16<00:00,  3.35it/s, MSE=0.041] 
04:29:42 - INFO: Starting epoch 144:
100%|██████████| 340/340 [02:16<00:00,  3.04it/s, MSE=0.0126] 
04:31:59 - INFO: Starting epoch 145:
100%|██████████| 340/340 [02:16<00:00,  3.38it/s, MSE=0.0186]
04:34:16 - INFO: Starting epoch 146:
100%|██████████| 340/340 [02:19<00:00,  3.21it/s, MSE=0.0195]
04:36:36 - INFO: Starting epoch 147:
100%|██████████| 340/340 [02:19<00:00,  2.58it/s, MSE=0.00809]
04:38:55 - INFO: Starting epoch 148:
100%|██████████| 340/340 [02:20<00:00,  3.04it/s, MSE=0.0113]
04:41:15 - INFO: Starting epoch 149:
100%|██████████| 340/340 [02:19<00:00,  3.17it/s, MSE=0.013] 

2.6. Sampling with a trained model

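We can load whichever checkpoint looked good during training and sample from it. This project is only a demonstration, and generating cars may not have much value in itself. But NovelAI can already produce very high-quality anime-style artwork, so using this project to understand the underlying principles of diffusion models should make it easier to work with improved diffusion variants in the future.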

In [6]

import paddle

model = UNet()
model.set_state_dict(paddle.load("car_models/ddpm_uncond140.pdparams"))   # load the checkpoint
diffusion = Diffusion(img_size=64, device="cuda")

sampled_images = diffusion.sample(model, n=8)

# Display the sampled images
for i in range(8):
    img = sampled_images[i].transpose([1, 2, 0])
    img = np.array(img).astype("uint8")
    plt.subplot(2, 4,i+1)
    plt.imshow(img)
plt.show()
05:37:15 - INFO: Sampling 8 new images....
499it [00:22, 22.61it/s]

<Figure size 640x480 with 8 Axes>

3. Conditional generation (guiding image generation with labels)

3.1. Training process explained

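As in the unconditional case, we obtain the training targets by sampling from the forward process, train a U-Net, and embed an encoding of the time step in the model input, much like positional encoding in the Transformer, which makes the model easier to train. Here we additionally encode the class label and feed it to the model as well. The cfg value sets the balance between the conditional and unconditional predictions: the larger cfg is, the more the conditional prediction dominates the result. In the sampling code the two noise predictions are blended by linear interpolation, predicted noise = (1 - cfg) * unconditional + cfg * conditional, so any cfg above 1 extrapolates past the conditional prediction.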

  • cfg: classifier-free guidance (label guidance)
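
The blending step can be sketched on plain Python lists. The `lerp` helper below follows the usual definition start + weight * (end - start), which is the blend the sampling code later performs with `paddle.lerp`; the function names and toy values are assumptions for illustration.

```python
def lerp(start, end, weight):
    """Linear interpolation; weights above 1 extrapolate beyond `end`."""
    return start + weight * (end - start)

def guided_noise(uncond, cond, cfg_scale):
    """Blend unconditional and conditional noise predictions element by element."""
    return [lerp(u, c, cfg_scale) for u, c in zip(uncond, cond)]

uncond = [1.0, 0.0]   # toy prediction without a label
cond = [2.0, 0.5]     # toy prediction with a label
guided = guided_noise(uncond, cond, cfg_scale=3)
```

In this sketch a weight of 0 returns the unconditional prediction and a weight of 1 returns the conditional one. Note that the project's sampling code only performs the blend when `cfg_scale > 0` and otherwise uses the conditional prediction as is.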

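In addition, the training below keeps an exponential moving average of the model parameters, blending the previously averaged weights with the current ones. This reduces the influence of outliers on the parameter updates and so gives more stable updates.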

  • ema: exponential moving average
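
The `EMA` class used later comes from the project's `modules.py`, which is not shown here. The sketch below is therefore only an assumption about what such an update typically looks like, written for a flat list of weights; the function name is hypothetical and the decay of 0.995 is the value passed to `EMA` in the training code.

```python
def ema_update(ema_weights, model_weights, beta=0.995):
    """Move each averaged weight a small step towards the current model weight."""
    return [beta * e + (1 - beta) * w for e, w in zip(ema_weights, model_weights)]

ema = [1.0, -2.0]
current = [0.0, 0.0]
ema_once = ema_update(ema, current)

# Repeated updates drift slowly towards the current weights.
ema_many = ema
for _ in range(2000):
    ema_many = ema_update(ema_many, current)
```

With a decay this close to 1, a single noisy training step changes the averaged weights by only 0.5% of the gap, which is where the smoothing comes from.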

Restart the kernel before running the code below to free the GPU memory in use!

3.2. Extracting the dataset

We use a flower dataset with five classes, so that later, when sampling, we can pick one class to generate.

In [1]

# Unzip the flower dataset
import os
if not os.path.exists("work/flowers"):
    !mkdir work/flowers
!unzip -oq data/data173680/flowers.zip -d work/flowers

In [2]

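# Load the dataset
"""Conditional generation needs the image labels as well, so we define a custom dataset here."""

# 1. Write the image paths to a txt file. flowers is originally a classification dataset; here we take both its training and validation sets and use them together as the training set for our generative model.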
import os
train_sunflower = os.listdir("work/flowers/pic/train/sunflower")            # 0: sunflower
valid_sunflower = os.listdir("work/flowers/pic/validation/sunflower")       # 0: sunflower
train_rose      = os.listdir("work/flowers/pic/train/rose")                 # 1: rose
valid_rose      = os.listdir("work/flowers/pic/validation/rose")            # 1: rose
train_tulip     = os.listdir("work/flowers/pic/train/tulip")                # 2: tulip
valid_tulip     = os.listdir("work/flowers/pic/validation/tulip")           # 2: tulip
train_dandelion = os.listdir("work/flowers/pic/train/dandelion")            # 3: dandelion
valid_dandelion = os.listdir("work/flowers/pic/validation/dandelion")       # 3: dandelion
train_daisy     = os.listdir("work/flowers/pic/train/daisy")                # 4: daisy
valid_daisy     = os.listdir("work/flowers/pic/validation/daisy")           # 4: daisy

with open("flowers_data.txt", 'w') as f:
    for image in train_sunflower:
        f.write("work/flowers/pic/train/sunflower/" + image + ";" + "0" + "\n")
    for image in valid_sunflower:
        f.write("work/flowers/pic/validation/sunflower/" + image + ";" + "0" + "\n")
    for image in train_rose:
        f.write("work/flowers/pic/train/rose/" + image + ";" + "1" + "\n")
    for image in valid_rose:
        f.write("work/flowers/pic/validation/rose/" + image + ";" + "1" + "\n")
    for image in train_tulip:
        f.write("work/flowers/pic/train/tulip/" + image + ";" + "2" + "\n")
    for image in valid_tulip:
        f.write("work/flowers/pic/validation/tulip/" + image + ";" + "2" + "\n")
    for image in train_dandelion:
        f.write("work/flowers/pic/train/dandelion/" + image + ";" + "3" + "\n")
    for image in valid_dandelion:
        f.write("work/flowers/pic/validation/dandelion/" + image + ";" + "3" + "\n")
    for image in train_daisy:
        f.write("work/flowers/pic/train/daisy/" + image + ";" + "4" + "\n")
    for image in valid_daisy:
        f.write("work/flowers/pic/validation/daisy/" + image + ";" + "4" + "\n")

3.3. Building the dataset

Here the data iterator has to return both the image and its label, so we build the dataset with the basic API.
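
As a small stand-alone sketch of reading the `path;label` index format written in section 3.2, the code below parses such lines into (path, label) pairs. The function name and the two demo paths are made up for illustration, and an in-memory file stands in for flowers_data.txt.

```python
import io

def parse_index(lines):
    """Turn 'path;label' lines into (path, int_label) pairs, skipping blank lines."""
    samples = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        path, label = line.rsplit(";", 1)  # split on the last ';' only
        samples.append((path, int(label)))
    return samples

# Two demo entries, each terminated by a newline as in the file written earlier.
demo_file = io.StringIO("demo/sunflower/1.jpg;0\ndemo/rose/2.jpg;1\n")
lines = demo_file.readlines()
samples = parse_index(lines)
```

Note that `readlines()` returns exactly one entry per written line here; a file that ends with a newline does not produce an extra empty entry.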

In [3]

# 2. Build the dataset
# Apply the transforms and return the image with its label
import paddle.vision as V
from PIL import Image
from paddle.io import Dataset, DataLoader
from tqdm import tqdm

# Data transforms
transforms = V.transforms.Compose([
        V.transforms.Resize(80),  # args.image_size + 1/4 *args.image_size
        V.transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),
        V.transforms.ToTensor(),
        V.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

class TrainDataFlowers(Dataset):
    def __init__(self, txt_path="flowers_data.txt"):
        with open(txt_path, "r") as f:
            data = f.readlines()
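        # readlines() does not return a trailing empty line, so keep every non-blank line
        self.image_paths = [line.strip() for line in data if line.strip()]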
    
    def __getitem__(self, index):
        image_path, label = self.image_paths[index].split(";")
        image = Image.open(image_path)
        image = transforms(image)

        label = int(label)
        
        return image, label
    
    def __len__(self):
        return len(self.image_paths)

dataset = TrainDataFlowers()
dataloader = DataLoader(dataset, batch_size=24, shuffle=True)

if __name__ == "__main__": # check that the dataset is usable
    pbar = tqdm(dataloader)
    for i, (images, labels) in enumerate(pbar):
        pass
    print("ok")
  0%|          | 0/181 [00:00<?, ?it/s]W1023 15:49:27.184664  3398 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1023 15:49:27.188580  3398 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
100%|██████████| 181/181 [00:15<00:00, 11.37it/s]
ok
 

3.4. Training procedure

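During training, hyperparameters can be changed by editing the fields of the ARGS class. Once you know that the loss is simply the mean squared error between two image-shaped tensors, the sampled noise and the noise predicted by the model, the code becomes fairly simple. Compared with a GAN, a diffusion model's hyperparameters are easier to tune and the model is easier to train.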

In [4]

import os
import paddle
import copy
import paddle.nn as nn
from matplotlib import pyplot as plt
%matplotlib inline
from tqdm import tqdm
from paddle import optimizer
from modules import UNet_conditional, EMA
import logging
import numpy as np
logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")

class Diffusion:
    def __init__(self, noise_steps=500, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.beta = self.prepare_noise_schedule()
        self.alpha = 1. - self.beta
        self.alpha_hat = paddle.cumprod(self.alpha, dim=0)

        self.img_size = img_size
        self.device = device

    def prepare_noise_schedule(self):
        return paddle.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        sqrt_alpha_hat = paddle.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = paddle.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = paddle.randn(shape=x.shape)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return paddle.randint(low=1, high=self.noise_steps, shape=(n,))

    def sample(self, model, n, labels, cfg_scale=3):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with paddle.no_grad():
            x = paddle.randn((n, 3, self.img_size, self.img_size))
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = paddle.to_tensor([i] * x.shape[0]).astype("int64")
                predicted_noise = model(x, t, labels)
                if cfg_scale > 0:
                    uncond_predicted_noise = model(x, t, None)
                    cfg_scale = paddle.to_tensor(cfg_scale).astype("float32")
                    predicted_noise = paddle.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = paddle.randn(shape=x.shape)
                else:
                    noise = paddle.zeros_like(x)
                x = 1 / paddle.sqrt(alpha) * (x - ((1 - alpha) / (paddle.sqrt(1 - alpha_hat))) * predicted_noise) + paddle.sqrt(beta) * noise
        model.train()
        # Map the result from [-1, 1] to [0, 255] for display.
        x = (x.clip(-1, 1) + 1) / 2
        x = (x * 255)
        return x


def train(args):
    # setup_logging(args.run_name)
    device = args.device
    dataloader = args.dataloader
    model = UNet_conditional(num_classes=args.num_classes)
    opt = optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    l = len(dataloader)
    ema = EMA(0.995)
    ema_model = copy.deepcopy(model)
    ema_model.eval()
    # print("ema_model", ema_model)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, (images, labels) in enumerate(pbar):
            t = diffusion.sample_timesteps(images.shape[0])
            x_t, noise = diffusion.noise_images(images, t)
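            # Drop the labels for roughly 10% of the batches so the model also learns
            # the unconditional prediction used by classifier-free guidance.
            if np.random.random() < 0.1:
                labels = None
            predicted_noise = model(x_t, t, labels)
            loss = mse(noise, predicted_noise)  # loss: MSE between the true and the predicted noise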

            opt.clear_grad()
            loss.backward()
            opt.step()

            ema.step_ema(ema_model, model)
            pbar.set_postfix(MSE=loss.item())
            # logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)

        if epoch % 30 == 0:     # Save the model and visualize the training results.
            paddle.save(model.state_dict(), f"models/ddpm_cond{epoch}.pdparams")

            labels = paddle.arange(5).astype("int64")
            # Sample 10 images in total (two batches of five).
            # From left to right: sunflower, rose, tulip, dandelion, daisy
            sampled_images1 = diffusion.sample(model, n=len(labels), labels=labels)
            sampled_images2 = diffusion.sample(model, n=len(labels), labels=labels)
            # ema_sampled_images = diffusion.sample(ema_model, n=len(labels), labels=labels)
            for i in range(5):
                img = sampled_images1[i].transpose([1, 2, 0])
                img = np.array(img).astype("uint8")
                plt.subplot(2,5,i+1)
                plt.imshow(img)
            for i in range(5):
                img = sampled_images2[i].transpose([1, 2, 0])
                img = np.array(img).astype("uint8")
                plt.subplot(2,5,i+1+5)
                plt.imshow(img)
            plt.show()


def launch():
    # Hyperparameter settings
    class ARGS:
        def __init__(self):
            self.run_name = "DDPM_conditional"
            self.epochs = 300
            self.batch_size = 48
            self.image_size = 64
            self.device = "cuda"
            self.lr = 1.5e-4
            self.num_classes = 5
            self.dataloader = dataloader


    args = ARGS()
    train(args)


if __name__ == '__main__':
    # Train
    launch()
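
The two pieces of arithmetic that the cell above relies on, the closed-form forward noising in `noise_images` and the classifier-free guidance mix in `sample`, can be checked without a GPU. The block below is only a rough NumPy sketch, not the notebook's implementation: `noise_images_np`, `guided_noise`, and the tiny random batch are hypothetical stand-ins, and the schedule values are copied from the `Diffusion.__init__` defaults.

```python
import numpy as np

# Schedule values copied from the Diffusion.__init__ defaults above.
noise_steps, beta_start, beta_end = 500, 1e-4, 0.02

beta = np.linspace(beta_start, beta_end, noise_steps)
alpha = 1.0 - beta
alpha_hat = np.cumprod(alpha)  # plays the role of self.alpha_hat


def noise_images_np(x, t, rng):
    """Hypothetical NumPy stand-in for Diffusion.noise_images (x_0 -> x_t in one step)."""
    sqrt_alpha_hat = np.sqrt(alpha_hat[t])[:, None, None, None]
    sqrt_one_minus_alpha_hat = np.sqrt(1.0 - alpha_hat[t])[:, None, None, None]
    eps = rng.standard_normal(x.shape)
    return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * eps, eps


def guided_noise(uncond, cond, cfg_scale):
    """Same arithmetic as paddle.lerp(uncond, cond, cfg_scale): uncond + w * (cond - uncond)."""
    return uncond + cfg_scale * (cond - uncond)


rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(4, 3, 8, 8))  # hypothetical tiny batch in [-1, 1]
t = np.array([1, 100, 300, 499])                # one timestep per image
x_t, eps = noise_images_np(x0, t, rng)

# Share of the original image that survives at each sampled timestep.
print(np.sqrt(alpha_hat[t]))
```

Because the guidance weight is used as an interpolation factor, any `cfg_scale` above 1 extrapolates past the conditional prediction instead of blending the two, which is why the default of 3 strengthens the class signal. The training output of the cell above follows.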
03:56:58 - INFO: Starting epoch 0:
100%|██████████| 181/181 [01:04<00:00,  3.76it/s, MSE=0.172]
03:58:03 - INFO: Sampling 5 new images....
499it [00:44, 11.13it/s]
03:58:48 - INFO: Sampling 5 new images....
499it [00:44, 11.28it/s]

<Figure size 640x480 with 10 Axes>
03:59:33 - INFO: Starting epoch 1:
100%|██████████| 181/181 [01:02<00:00,  3.81it/s, MSE=0.104]
04:00:36 - INFO: Starting epoch 2:
100%|██████████| 181/181 [01:02<00:00,  3.78it/s, MSE=0.103] 
04:01:38 - INFO: Starting epoch 3:
100%|██████████| 181/181 [01:02<00:00,  3.75it/s, MSE=0.0912]
04:02:41 - INFO: Starting epoch 4:
100%|██████████| 181/181 [01:02<00:00,  3.80it/s, MSE=0.0649]
04:03:43 - INFO: Starting epoch 5:
100%|██████████| 181/181 [00:59<00:00,  4.66it/s, MSE=0.0631]
04:04:43 - INFO: Starting epoch 6:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.179] 
04:05:38 - INFO: Starting epoch 7:
100%|██████████| 181/181 [00:55<00:00,  4.59it/s, MSE=0.0908]
04:06:33 - INFO: Starting epoch 8:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.158] 
04:07:29 - INFO: Starting epoch 9:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.171] 
04:08:24 - INFO: Starting epoch 10:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0362]
04:09:20 - INFO: Starting epoch 11:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0444]
04:10:16 - INFO: Starting epoch 12:
100%|██████████| 181/181 [00:55<00:00,  4.63it/s, MSE=0.0393]
04:11:12 - INFO: Starting epoch 13:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.064] 
04:12:07 - INFO: Starting epoch 14:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.035] 
04:13:03 - INFO: Starting epoch 15:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.063] 
04:13:58 - INFO: Starting epoch 16:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.0157]
04:14:54 - INFO: Starting epoch 17:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0159]
04:15:49 - INFO: Starting epoch 18:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0212]
04:16:45 - INFO: Starting epoch 19:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0252]
04:17:40 - INFO: Starting epoch 20:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0192]
04:18:35 - INFO: Starting epoch 21:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0361]
04:19:31 - INFO: Starting epoch 22:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.0177]
04:20:26 - INFO: Starting epoch 23:
100%|██████████| 181/181 [00:55<00:00,  4.54it/s, MSE=0.0527]
04:21:22 - INFO: Starting epoch 24:
100%|██████████| 181/181 [00:55<00:00,  4.59it/s, MSE=0.0458]
04:22:17 - INFO: Starting epoch 25:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0539]
04:23:13 - INFO: Starting epoch 26:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.205] 
04:24:09 - INFO: Starting epoch 27:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0463]
04:25:04 - INFO: Starting epoch 28:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.152] 
04:26:00 - INFO: Starting epoch 29:
100%|██████████| 181/181 [00:55<00:00,  4.59it/s, MSE=0.284] 
04:26:55 - INFO: Starting epoch 30:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.0896]
04:27:51 - INFO: Sampling 5 new images....
499it [00:44, 11.27it/s]
04:28:36 - INFO: Sampling 5 new images....
499it [00:45, 11.01it/s]

<Figure size 640x480 with 10 Axes>
04:29:21 - INFO: Starting epoch 31:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.299] 
04:30:17 - INFO: Starting epoch 32:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0226]
04:31:12 - INFO: Starting epoch 33:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.00727]
04:32:08 - INFO: Starting epoch 34:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.132] 
04:33:03 - INFO: Starting epoch 35:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0498]
04:33:59 - INFO: Starting epoch 36:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.0107]
04:34:55 - INFO: Starting epoch 37:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0116]
04:35:50 - INFO: Starting epoch 38:
100%|██████████| 181/181 [00:55<00:00,  4.56it/s, MSE=0.044] 
04:36:46 - INFO: Starting epoch 39:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.167] 
04:37:41 - INFO: Starting epoch 40:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0359]
04:38:37 - INFO: Starting epoch 41:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.0064]
04:39:33 - INFO: Starting epoch 42:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.0107]
04:40:28 - INFO: Starting epoch 43:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0216]
04:41:24 - INFO: Starting epoch 44:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0361]
04:42:20 - INFO: Starting epoch 45:
100%|██████████| 181/181 [00:55<00:00,  4.59it/s, MSE=0.0368]
04:43:15 - INFO: Starting epoch 46:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0283]
04:44:10 - INFO: Starting epoch 47:
100%|██████████| 181/181 [00:55<00:00,  4.63it/s, MSE=0.0352]
04:45:06 - INFO: Starting epoch 48:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0499]
04:46:01 - INFO: Starting epoch 49:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0359]
04:46:56 - INFO: Starting epoch 50:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.0555]
04:47:52 - INFO: Starting epoch 51:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.14]  
04:48:47 - INFO: Starting epoch 52:
100%|██████████| 181/181 [00:54<00:00,  4.65it/s, MSE=0.0136]
04:49:42 - INFO: Starting epoch 53:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0242]
04:50:38 - INFO: Starting epoch 54:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0252]
04:51:33 - INFO: Starting epoch 55:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0274]
04:52:29 - INFO: Starting epoch 56:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.0727]
04:53:24 - INFO: Starting epoch 57:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.023] 
04:54:20 - INFO: Starting epoch 58:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0457]
04:55:15 - INFO: Starting epoch 59:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0123]
04:56:10 - INFO: Starting epoch 60:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0125]
04:57:07 - INFO: Sampling 5 new images....
499it [00:45, 10.91it/s]
04:57:52 - INFO: Sampling 5 new images....
499it [00:45, 10.90it/s]

<Figure size 640x480 with 10 Axes>
04:58:39 - INFO: Starting epoch 61:
100%|██████████| 181/181 [00:56<00:00,  4.53it/s, MSE=0.00765]
04:59:35 - INFO: Starting epoch 62:
100%|██████████| 181/181 [00:55<00:00,  4.63it/s, MSE=0.00355]
05:00:30 - INFO: Starting epoch 63:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0256]
05:01:26 - INFO: Starting epoch 64:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0413]
05:02:22 - INFO: Starting epoch 65:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0146]
05:03:17 - INFO: Starting epoch 66:
100%|██████████| 181/181 [00:56<00:00,  4.57it/s, MSE=0.00737]
05:04:13 - INFO: Starting epoch 67:
100%|██████████| 181/181 [00:56<00:00,  4.63it/s, MSE=0.00363]
05:05:09 - INFO: Starting epoch 68:
100%|██████████| 181/181 [00:56<00:00,  4.58it/s, MSE=0.121] 
05:06:06 - INFO: Starting epoch 69:
100%|██████████| 181/181 [00:56<00:00,  4.53it/s, MSE=0.0124]
05:07:02 - INFO: Starting epoch 70:
100%|██████████| 181/181 [00:56<00:00,  4.53it/s, MSE=0.0235]
05:07:59 - INFO: Starting epoch 71:
100%|██████████| 181/181 [00:55<00:00,  4.52it/s, MSE=0.084] 
05:08:55 - INFO: Starting epoch 72:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.022] 
05:09:50 - INFO: Starting epoch 73:
100%|██████████| 181/181 [00:56<00:00,  4.61it/s, MSE=0.00922]
05:10:47 - INFO: Starting epoch 74:
100%|██████████| 181/181 [00:56<00:00,  4.27it/s, MSE=0.0059]
05:11:43 - INFO: Starting epoch 75:
100%|██████████| 181/181 [00:56<00:00,  4.60it/s, MSE=0.00901]
05:12:40 - INFO: Starting epoch 76:
100%|██████████| 181/181 [00:56<00:00,  4.60it/s, MSE=0.0261]
05:13:36 - INFO: Starting epoch 77:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.0317]
05:14:32 - INFO: Starting epoch 78:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.0379]
05:15:27 - INFO: Starting epoch 79:
100%|██████████| 181/181 [00:54<00:00,  4.62it/s, MSE=0.0126]
05:16:22 - INFO: Starting epoch 80:
100%|██████████| 181/181 [00:55<00:00,  4.57it/s, MSE=0.0129]
05:17:17 - INFO: Starting epoch 81:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0174]
05:18:13 - INFO: Starting epoch 82:
100%|██████████| 181/181 [00:57<00:00,  4.59it/s, MSE=0.00267]
05:19:11 - INFO: Starting epoch 83:
100%|██████████| 181/181 [00:57<00:00,  4.61it/s, MSE=0.00863]
05:20:08 - INFO: Starting epoch 84:
100%|██████████| 181/181 [00:55<00:00,  4.59it/s, MSE=0.0928]
05:21:04 - INFO: Starting epoch 85:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.0151]
05:21:59 - INFO: Starting epoch 86:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0231]
05:22:55 - INFO: Starting epoch 87:
100%|██████████| 181/181 [00:55<00:00,  4.50it/s, MSE=0.0442]
05:23:51 - INFO: Starting epoch 88:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.00999]
05:24:47 - INFO: Starting epoch 89:
100%|██████████| 181/181 [00:55<00:00,  4.57it/s, MSE=0.00467]
05:25:42 - INFO: Starting epoch 90:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0219]
05:26:38 - INFO: Sampling 5 new images....
499it [00:44, 11.12it/s]
05:27:23 - INFO: Sampling 5 new images....
499it [00:45, 11.06it/s]

<Figure size 640x480 with 10 Axes>
05:28:09 - INFO: Starting epoch 91:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.00285]
05:29:05 - INFO: Starting epoch 92:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.112] 
05:30:00 - INFO: Starting epoch 93:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.0108]
05:30:56 - INFO: Starting epoch 94:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.0281]
05:31:51 - INFO: Starting epoch 95:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.0355]
05:32:47 - INFO: Starting epoch 96:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.133] 
05:33:42 - INFO: Starting epoch 97:
100%|██████████| 181/181 [00:56<00:00,  4.56it/s, MSE=0.0138]
05:34:39 - INFO: Starting epoch 98:
100%|██████████| 181/181 [00:56<00:00,  4.66it/s, MSE=0.00963]
05:35:35 - INFO: Starting epoch 99:
100%|██████████| 181/181 [00:56<00:00,  4.59it/s, MSE=0.0298]
05:36:31 - INFO: Starting epoch 100:
100%|██████████| 181/181 [00:56<00:00,  4.65it/s, MSE=0.00709]
05:37:27 - INFO: Starting epoch 101:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0737]
05:38:23 - INFO: Starting epoch 102:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0105]
05:39:18 - INFO: Starting epoch 103:
100%|██████████| 181/181 [00:56<00:00,  4.56it/s, MSE=0.00631]
05:40:14 - INFO: Starting epoch 104:
100%|██████████| 181/181 [00:55<00:00,  4.55it/s, MSE=0.00662]
05:41:10 - INFO: Starting epoch 105:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.262] 
05:42:05 - INFO: Starting epoch 106:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0206]
05:43:00 - INFO: Starting epoch 107:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.00979]
05:43:56 - INFO: Starting epoch 108:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0121]
05:44:52 - INFO: Starting epoch 109:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.00493]
05:45:48 - INFO: Starting epoch 110:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0158]
05:46:43 - INFO: Starting epoch 111:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.00567]
05:47:39 - INFO: Starting epoch 112:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.00994]
05:48:34 - INFO: Starting epoch 113:
100%|██████████| 181/181 [00:56<00:00,  4.57it/s, MSE=0.00712]
05:49:31 - INFO: Starting epoch 114:
100%|██████████| 181/181 [00:56<00:00,  4.61it/s, MSE=0.0414]
05:50:27 - INFO: Starting epoch 115:
100%|██████████| 181/181 [00:56<00:00,  4.61it/s, MSE=0.00445]
05:51:23 - INFO: Starting epoch 116:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0967]
05:52:19 - INFO: Starting epoch 117:
100%|██████████| 181/181 [00:55<00:00,  4.55it/s, MSE=0.0384]
05:53:15 - INFO: Starting epoch 118:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0122]
05:54:10 - INFO: Starting epoch 119:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0342]
05:55:06 - INFO: Starting epoch 120:
100%|██████████| 181/181 [00:56<00:00,  4.63it/s, MSE=0.0257]
05:56:02 - INFO: Sampling 5 new images....
499it [00:44, 11.29it/s]
05:56:46 - INFO: Sampling 5 new images....
499it [00:45, 11.03it/s]

<Figure size 640x480 with 10 Axes>
05:57:32 - INFO: Starting epoch 121:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.00285]
05:58:28 - INFO: Starting epoch 122:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0274]
05:59:24 - INFO: Starting epoch 123:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0629]
06:00:19 - INFO: Starting epoch 124:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0203]
06:01:15 - INFO: Starting epoch 125:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0619]
06:02:10 - INFO: Starting epoch 126:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0456]
06:03:05 - INFO: Starting epoch 127:
100%|██████████| 181/181 [00:55<00:00,  4.54it/s, MSE=0.0157]
06:04:01 - INFO: Starting epoch 128:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0136]
06:04:57 - INFO: Starting epoch 129:
100%|██████████| 181/181 [00:55<00:00,  4.63it/s, MSE=0.115] 
06:05:52 - INFO: Starting epoch 130:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0519]
06:06:48 - INFO: Starting epoch 131:
100%|██████████| 181/181 [00:56<00:00,  4.61it/s, MSE=0.0179]
06:07:44 - INFO: Starting epoch 132:
100%|██████████| 181/181 [00:56<00:00,  4.72it/s, MSE=0.0211]
06:08:40 - INFO: Starting epoch 133:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0172]
06:09:36 - INFO: Starting epoch 134:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.134] 
06:10:31 - INFO: Starting epoch 135:
100%|██████████| 181/181 [00:55<00:00,  4.56it/s, MSE=0.201] 
06:11:26 - INFO: Starting epoch 136:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0325]
06:12:22 - INFO: Starting epoch 137:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.0203]
06:13:17 - INFO: Starting epoch 138:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.00265]
06:14:13 - INFO: Starting epoch 139:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.00424]
06:15:08 - INFO: Starting epoch 140:
100%|██████████| 181/181 [00:55<00:00,  4.73it/s, MSE=0.00383]
06:16:03 - INFO: Starting epoch 141:
100%|██████████| 181/181 [00:54<00:00,  4.66it/s, MSE=0.0153]
06:16:58 - INFO: Starting epoch 142:
100%|██████████| 181/181 [00:54<00:00,  4.68it/s, MSE=0.0284]
06:17:53 - INFO: Starting epoch 143:
100%|██████████| 181/181 [00:55<00:00,  4.55it/s, MSE=0.00366]
06:18:48 - INFO: Starting epoch 144:
100%|██████████| 181/181 [00:56<00:00,  4.52it/s, MSE=0.0912]
06:19:44 - INFO: Starting epoch 145:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.00704]
06:20:40 - INFO: Starting epoch 146:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.0042]
06:21:36 - INFO: Starting epoch 147:
100%|██████████| 181/181 [00:54<00:00,  4.63it/s, MSE=0.011] 
06:22:31 - INFO: Starting epoch 148:
100%|██████████| 181/181 [00:54<00:00,  4.66it/s, MSE=0.0379]
06:23:26 - INFO: Starting epoch 149:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.11]  
06:24:21 - INFO: Starting epoch 150:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.00667]
06:25:17 - INFO: Sampling 5 new images....
499it [00:43, 11.47it/s]
06:26:01 - INFO: Sampling 5 new images....
499it [00:42, 11.65it/s]

<Figure size 640x480 with 10 Axes>
06:26:44 - INFO: Starting epoch 151:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.00426]
06:27:39 - INFO: Starting epoch 152:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0859]
06:28:34 - INFO: Starting epoch 153:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.0238]
06:29:29 - INFO: Starting epoch 154:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.0261]
06:30:24 - INFO: Starting epoch 155:
100%|██████████| 181/181 [00:54<00:00,  4.73it/s, MSE=0.049] 
06:31:19 - INFO: Starting epoch 156:
100%|██████████| 181/181 [00:54<00:00,  4.68it/s, MSE=0.00625]
06:32:14 - INFO: Starting epoch 157:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.0107]
06:33:09 - INFO: Starting epoch 158:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.13]  
06:34:04 - INFO: Starting epoch 159:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.0495]
06:34:59 - INFO: Starting epoch 160:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0112]
06:35:54 - INFO: Starting epoch 161:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.00525]
06:36:49 - INFO: Starting epoch 162:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.00437]
06:37:44 - INFO: Starting epoch 163:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.00408]
06:38:39 - INFO: Starting epoch 164:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0177]
06:39:35 - INFO: Starting epoch 165:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.00417]
06:40:30 - INFO: Starting epoch 166:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0786]
06:41:25 - INFO: Starting epoch 167:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.0205]
06:42:20 - INFO: Starting epoch 168:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0952]
06:43:15 - INFO: Starting epoch 169:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0118]
06:44:10 - INFO: Starting epoch 170:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.253] 
06:45:06 - INFO: Starting epoch 171:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.00373]
06:46:01 - INFO: Starting epoch 172:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.00618]
06:46:56 - INFO: Starting epoch 173:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.0083]
06:47:50 - INFO: Starting epoch 174:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.0171]
06:48:45 - INFO: Starting epoch 175:
100%|██████████| 181/181 [00:54<00:00,  4.60it/s, MSE=0.0216]
06:49:40 - INFO: Starting epoch 176:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.00168]
06:50:35 - INFO: Starting epoch 177:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0166]
06:51:30 - INFO: Starting epoch 178:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.00656]
06:52:25 - INFO: Starting epoch 179:
100%|██████████| 181/181 [00:54<00:00,  4.68it/s, MSE=0.114] 
06:53:20 - INFO: Starting epoch 180:
100%|██████████| 181/181 [00:54<00:00,  4.70it/s, MSE=0.00226]
06:54:15 - INFO: Sampling 5 new images....
499it [00:42, 11.66it/s]
06:54:58 - INFO: Sampling 5 new images....
499it [00:43, 11.56it/s]

<Figure size 640x480 with 10 Axes>
06:55:42 - INFO: Starting epoch 181:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.0484]
06:56:37 - INFO: Starting epoch 182:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.0196]
06:57:32 - INFO: Starting epoch 183:
100%|██████████| 181/181 [00:54<00:00,  4.68it/s, MSE=0.00695]
06:58:27 - INFO: Starting epoch 184:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.0515]
06:59:21 - INFO: Starting epoch 185:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.00296]
07:00:16 - INFO: Starting epoch 186:
100%|██████████| 181/181 [00:54<00:00,  4.76it/s, MSE=0.0878]
07:01:11 - INFO: Starting epoch 187:
100%|██████████| 181/181 [00:54<00:00,  4.73it/s, MSE=0.0574]
07:02:06 - INFO: Starting epoch 188:
100%|██████████| 181/181 [00:54<00:00,  4.75it/s, MSE=0.00468]
07:03:00 - INFO: Starting epoch 189:
100%|██████████| 181/181 [00:54<00:00,  4.61it/s, MSE=0.0289]
07:03:55 - INFO: Starting epoch 190:
100%|██████████| 181/181 [00:54<00:00,  4.73it/s, MSE=0.0167]
07:04:50 - INFO: Starting epoch 191:
100%|██████████| 181/181 [00:54<00:00,  4.65it/s, MSE=0.0505]
07:05:45 - INFO: Starting epoch 192:
100%|██████████| 181/181 [00:54<00:00,  4.78it/s, MSE=0.00374]
07:06:39 - INFO: Starting epoch 193:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.0176]
07:07:34 - INFO: Starting epoch 194:
100%|██████████| 181/181 [00:54<00:00,  4.76it/s, MSE=0.0161]
07:08:29 - INFO: Starting epoch 195:
100%|██████████| 181/181 [00:54<00:00,  4.63it/s, MSE=0.0161]
07:09:24 - INFO: Starting epoch 196:
100%|██████████| 181/181 [00:54<00:00,  4.66it/s, MSE=0.0358]
07:10:18 - INFO: Starting epoch 197:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.0694]
07:11:13 - INFO: Starting epoch 198:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.107] 
07:12:09 - INFO: Starting epoch 199:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0383]
07:13:04 - INFO: Starting epoch 200:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0169]
07:13:59 - INFO: Starting epoch 201:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0855]
07:14:54 - INFO: Starting epoch 202:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.00749]
07:15:49 - INFO: Starting epoch 203:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.00324]
07:16:45 - INFO: Starting epoch 204:
100%|██████████| 181/181 [00:54<00:00,  4.74it/s, MSE=0.0965]
07:17:40 - INFO: Starting epoch 205:
100%|██████████| 181/181 [00:54<00:00,  4.70it/s, MSE=0.0277]
07:18:34 - INFO: Starting epoch 206:
100%|██████████| 181/181 [00:54<00:00,  4.71it/s, MSE=0.0146]
07:19:29 - INFO: Starting epoch 207:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.00659]
07:20:24 - INFO: Starting epoch 208:
100%|██████████| 181/181 [00:54<00:00,  4.73it/s, MSE=0.0176]
07:21:19 - INFO: Starting epoch 209:
100%|██████████| 181/181 [00:54<00:00,  4.65it/s, MSE=0.12]  
07:22:14 - INFO: Starting epoch 210:
100%|██████████| 181/181 [00:54<00:00,  4.76it/s, MSE=0.0688]
07:23:10 - INFO: Sampling 5 new images....
499it [00:43, 11.17it/s]
07:23:53 - INFO: Sampling 5 new images....
499it [00:43, 11.55it/s]

<Figure size 640x480 with 10 Axes>
07:24:37 - INFO: Starting epoch 211:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.00553]
07:25:31 - INFO: Starting epoch 212:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0851]
07:26:26 - INFO: Starting epoch 213:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.0147]
07:27:21 - INFO: Starting epoch 214:
100%|██████████| 181/181 [00:55<00:00,  4.75it/s, MSE=0.0669]
07:28:16 - INFO: Starting epoch 215:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.00531]
07:29:11 - INFO: Starting epoch 216:
100%|██████████| 181/181 [00:54<00:00,  4.63it/s, MSE=0.0315]
07:30:06 - INFO: Starting epoch 217:
100%|██████████| 181/181 [00:54<00:00,  4.76it/s, MSE=0.147] 
07:31:01 - INFO: Starting epoch 218:
100%|██████████| 181/181 [00:54<00:00,  4.63it/s, MSE=0.0547]
07:31:56 - INFO: Starting epoch 219:
100%|██████████| 181/181 [00:54<00:00,  4.74it/s, MSE=0.036] 
07:32:50 - INFO: Starting epoch 220:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.00479]
07:33:45 - INFO: Starting epoch 221:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.0225]
07:34:40 - INFO: Starting epoch 222:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.0192]
07:35:35 - INFO: Starting epoch 223:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.00701]
07:36:30 - INFO: Starting epoch 224:
100%|██████████| 181/181 [00:54<00:00,  4.56it/s, MSE=0.036] 
07:37:25 - INFO: Starting epoch 225:
100%|██████████| 181/181 [00:54<00:00,  4.79it/s, MSE=0.0908]
07:38:19 - INFO: Starting epoch 226:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.00345]
07:39:14 - INFO: Starting epoch 227:
100%|██████████| 181/181 [00:54<00:00,  4.73it/s, MSE=0.0657]
07:40:09 - INFO: Starting epoch 228:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.0841]
07:41:04 - INFO: Starting epoch 229:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.00407]
07:41:59 - INFO: Starting epoch 230:
100%|██████████| 181/181 [00:54<00:00,  4.66it/s, MSE=0.0249]
07:42:54 - INFO: Starting epoch 231:
100%|██████████| 181/181 [00:54<00:00,  4.71it/s, MSE=0.0563]
07:43:48 - INFO: Starting epoch 232:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.052] 
07:44:43 - INFO: Starting epoch 233:
100%|██████████| 181/181 [00:54<00:00,  4.74it/s, MSE=0.0698]
07:45:38 - INFO: Starting epoch 234:
100%|██████████| 181/181 [00:54<00:00,  4.65it/s, MSE=0.0553] 
07:46:33 - INFO: Starting epoch 235:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.00646]
07:47:27 - INFO: Starting epoch 236:
100%|██████████| 181/181 [00:54<00:00,  4.68it/s, MSE=0.00258]
07:48:22 - INFO: Starting epoch 237:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0236]
07:49:17 - INFO: Starting epoch 238:
100%|██████████| 181/181 [00:54<00:00,  4.68it/s, MSE=0.0339]
07:50:12 - INFO: Starting epoch 239:
100%|██████████| 181/181 [00:54<00:00,  4.65it/s, MSE=0.00555]
07:51:07 - INFO: Starting epoch 240:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0571]
07:52:03 - INFO: Sampling 5 new images....
499it [00:42, 11.79it/s]
07:52:45 - INFO: Sampling 5 new images....
499it [00:42, 11.62it/s]

<Figure size 640x480 with 10 Axes>
07:53:28 - INFO: Starting epoch 241:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.00577]
07:54:24 - INFO: Starting epoch 242:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.0155]
07:55:18 - INFO: Starting epoch 243:
100%|██████████| 181/181 [00:54<00:00,  4.64it/s, MSE=0.0322]
07:56:13 - INFO: Starting epoch 244:
100%|██████████| 181/181 [00:55<00:00,  4.76it/s, MSE=0.00787]
07:57:08 - INFO: Starting epoch 245:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.116] 
07:58:03 - INFO: Starting epoch 246:
100%|██████████| 181/181 [00:54<00:00,  4.70it/s, MSE=0.0187]
07:58:58 - INFO: Starting epoch 247:
100%|██████████| 181/181 [00:54<00:00,  4.70it/s, MSE=0.059] 
07:59:52 - INFO: Starting epoch 248:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0248]
08:00:48 - INFO: Starting epoch 249:
100%|██████████| 181/181 [00:54<00:00,  4.60it/s, MSE=0.0254]
08:01:42 - INFO: Starting epoch 250:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.133] 
08:02:38 - INFO: Starting epoch 251:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0752]
08:03:33 - INFO: Starting epoch 252:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.00802]
08:04:28 - INFO: Starting epoch 253:
100%|██████████| 181/181 [00:54<00:00,  4.66it/s, MSE=0.254] 
08:05:23 - INFO: Starting epoch 254:
100%|██████████| 181/181 [00:54<00:00,  4.60it/s, MSE=0.0261]
08:06:18 - INFO: Starting epoch 255:
100%|██████████| 181/181 [00:54<00:00,  4.62it/s, MSE=0.0514]
08:07:13 - INFO: Starting epoch 256:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.00751]
08:08:08 - INFO: Starting epoch 257:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.0209]
08:09:03 - INFO: Starting epoch 258:
100%|██████████| 181/181 [00:54<00:00,  4.63it/s, MSE=0.0484]
08:09:58 - INFO: Starting epoch 259:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0255]
08:10:53 - INFO: Starting epoch 260:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.00507]
08:11:49 - INFO: Starting epoch 261:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0218]
08:12:45 - INFO: Starting epoch 262:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0203]
08:13:40 - INFO: Starting epoch 263:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.036] 
08:14:36 - INFO: Starting epoch 264:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0266]
08:15:32 - INFO: Starting epoch 265:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.0145] 
08:16:27 - INFO: Starting epoch 266:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.00483]
08:17:23 - INFO: Starting epoch 267:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0604]
08:18:18 - INFO: Starting epoch 268:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.0466]
08:19:14 - INFO: Starting epoch 269:
100%|██████████| 181/181 [00:56<00:00,  4.59it/s, MSE=0.00358]
08:20:10 - INFO: Starting epoch 270:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.0104]
08:21:06 - INFO: Sampling 5 new images....
499it [00:45, 11.06it/s]
08:21:51 - INFO: Sampling 5 new images....
499it [00:44, 11.33it/s]

<Figure size 640x480 with 10 Axes>
08:22:36 - INFO: Starting epoch 271:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0111]
08:23:31 - INFO: Starting epoch 272:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0474]
08:24:26 - INFO: Starting epoch 273:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.106] 
08:25:22 - INFO: Starting epoch 274:
100%|██████████| 181/181 [00:55<00:00,  4.63it/s, MSE=0.00758]
08:26:18 - INFO: Starting epoch 275:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.00715]
08:27:13 - INFO: Starting epoch 276:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0412]
08:28:08 - INFO: Starting epoch 277:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.00941]
08:29:04 - INFO: Starting epoch 278:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.0251]
08:29:59 - INFO: Starting epoch 279:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0385]
08:30:55 - INFO: Starting epoch 280:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0308]
08:31:50 - INFO: Starting epoch 281:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.108] 
08:32:45 - INFO: Starting epoch 282:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.049] 
08:33:41 - INFO: Starting epoch 283:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.0046]
08:34:37 - INFO: Starting epoch 284:
100%|██████████| 181/181 [00:55<00:00,  4.56it/s, MSE=0.00371]
08:35:32 - INFO: Starting epoch 285:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.00421]
08:36:27 - INFO: Starting epoch 286:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.00429]
08:37:22 - INFO: Starting epoch 287:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.0258]
08:38:17 - INFO: Starting epoch 288:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.0109]
08:39:13 - INFO: Starting epoch 289:
100%|██████████| 181/181 [00:55<00:00,  4.63it/s, MSE=0.152] 
08:40:08 - INFO: Starting epoch 290:
100%|██████████| 181/181 [00:55<00:00,  4.56it/s, MSE=0.0362]
08:41:04 - INFO: Starting epoch 291:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0161]
08:41:59 - INFO: Starting epoch 292:
100%|██████████| 181/181 [00:56<00:00,  4.52it/s, MSE=0.167] 
08:42:56 - INFO: Starting epoch 293:
100%|██████████| 181/181 [00:56<00:00,  4.46it/s, MSE=0.00373]
08:43:52 - INFO: Starting epoch 294:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0298]
08:44:48 - INFO: Starting epoch 295:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.0306]
08:45:43 - INFO: Starting epoch 296:
100%|██████████| 181/181 [00:56<00:00,  4.52it/s, MSE=0.0111]
08:46:40 - INFO: Starting epoch 297:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0232]
08:47:36 - INFO: Starting epoch 298:
100%|██████████| 181/181 [00:56<00:00,  4.03it/s, MSE=0.0217]
08:48:32 - INFO: Starting epoch 299:
100%|██████████| 181/181 [00:56<00:00,  4.63it/s, MSE=0.05]  

3.5. Sampling various flowers with the trained model

In [12]

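import numpy as np
import matplotlib.pyplot as plt
import paddle

# Build the conditional UNet and load the checkpoint saved at epoch 270
model = UNet_conditional(num_classes=5)
model.set_state_dict(paddle.load("models/ddpm_cond270.pdparams"))
diffusion = Diffusion(img_size=64, device="cuda")

# Labels 0, 1, 2, 3, 4 correspond to sunflower, rose, tulip, dandelion, daisy
labels = paddle.to_tensor([0, 0, 0, 0, 0]).astype("int64")
# Label guidance strength (classifier-free guidance scale)
cfg_scale = 7
sampled_images = diffusion.sample(model, n=len(labels), labels=labels, cfg_scale=cfg_scale)

# Show the samples side by side; each image is converted from CHW to HWC for matplotlib
for i in range(len(labels)):
    img = sampled_images[i].transpose([1, 2, 0])
    img = np.array(img).astype("uint8")
    plt.subplot(1, len(labels), i + 1)
    plt.imshow(img)
plt.show()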
09:01:17 - INFO: Sampling 5 new images....
499it [00:43, 11.60it/s]

<Figure size 640x480 with 5 Axes>
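
The `cfg_scale` argument above controls how strongly the label steers sampling. As a rough illustration, here is a minimal NumPy sketch of one common form of classifier-free guidance, in which the unconditional noise prediction is pushed toward the conditional one. The function name `apply_cfg` and the random stand-in arrays are hypothetical; whether `Diffusion.sample` in this project combines the two predictions in exactly this way is an assumption.

```python
import numpy as np

def apply_cfg(cond_noise, uncond_noise, cfg_scale):
    """Combine two noise predictions with classifier-free guidance.

    cfg_scale = 0 returns the unconditional prediction,
    cfg_scale = 1 returns the conditional prediction,
    cfg_scale > 1 extrapolates past the conditional prediction.
    """
    return uncond_noise + cfg_scale * (cond_noise - uncond_noise)

# Stand-ins for model(x, t, labels) and model(x, t, None); random here for illustration only
rng = np.random.default_rng(0)
cond = rng.standard_normal((5, 3, 64, 64)).astype("float32")
uncond = rng.standard_normal((5, 3, 64, 64)).astype("float32")

guided = apply_cfg(cond, uncond, cfg_scale=7)
print(guided.shape)
```

Under this formulation, a larger `cfg_scale` typically makes samples follow the label more closely at the cost of diversity, which is why it is worth trying a few values when sampling.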

4. Summary

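  • We derived the loss function of the diffusion model: starting from minimizing the negative log-likelihood, moving to optimizing the variational bound, simplifying that bound, and arriving at the final objective of predicting the noise.

  • We provided two versions of the code. Conditional generation works on a principle similar to today's popular text2image models, except that text2image conditions on more than a single class label as its encoding. See novelai for reference.

  • As a newer generation of generative model, diffusion trains very stably, and hyperparameter tuning is considerably simpler than for GANs.

  • To get better results, the most direct levers are a larger T (more diffusion steps) and more training epochs.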
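
As a compact recap of the noise-prediction objective summarized above, the following NumPy sketch noises a batch with the closed-form forward process and scores a noise prediction with MSE. It is only an illustration under stated assumptions: the linear schedule endpoints (`1e-4`, `0.02`), `T = 500`, the tiny image size, and the helper names are hypothetical choices, and the all-zero `eps_pred` is a placeholder for the UNet output rather than this project's model.

```python
import numpy as np

def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Linearly increasing noise variances beta_1..beta_T (assumed endpoints)."""
    return np.linspace(beta_start, beta_end, T)

def noise_images(x0, t, alpha_hat, rng):
    """Sample x_t directly from x_0: x_t = sqrt(alpha_hat_t) * x0 + sqrt(1 - alpha_hat_t) * eps."""
    eps = rng.standard_normal(x0.shape)
    signal = np.sqrt(alpha_hat[t])[:, None, None, None]
    noise = np.sqrt(1.0 - alpha_hat[t])[:, None, None, None]
    return signal * x0 + noise * eps, eps

def simple_loss(eps, eps_pred):
    """Simplified objective: mean squared error between true and predicted noise."""
    return float(np.mean((eps - eps_pred) ** 2))

T = 500                                      # assumed number of diffusion steps
beta = linear_beta_schedule(T)
alpha_hat = np.cumprod(1.0 - beta)           # cumulative product of alpha_t = 1 - beta_t

rng = np.random.default_rng(0)
x0 = rng.uniform(-1, 1, size=(4, 3, 8, 8))   # toy batch standing in for images scaled to [-1, 1]
t = rng.integers(1, T, size=4)               # one random timestep per image
x_t, eps = noise_images(x0, t, alpha_hat, rng)

eps_pred = np.zeros_like(eps)                # placeholder for model(x_t, t)
print(simple_loss(eps, eps_pred))
```

In an actual training loop, `eps_pred` would come from the network and this scalar would be the quantity reported as `MSE` for each batch.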
