昇思25天学习打卡营第5天 | CycleGAN图像风格迁移互换

模型介绍

模型简介

CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络,来自论文 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks 。该模型实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。

该模型一个重要应用领域是域迁移(Domain Adaptation),可以通俗地理解为图像风格迁移。其实在 CycleGAN 之前,就已经有了域迁移模型,比如 Pix2Pix ,但是 Pix2Pix 要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的,因此 CycleGAN 诞生了,它只需要两种域的数据,而不需要他们有严格对应关系,是一种新的无监督的图像迁移网络。

模型结构

CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成,其结构如下图所示(图片来源于原论文):

CycleGAN

为了方便理解,这里以苹果和橘子为例介绍。上图中 𝑋

可以理解为苹果,𝑌 为橘子;𝐺 为将苹果生成橘子风格的生成器,𝐹 为将橘子生成的苹果风格的生成器,𝐷𝑋 和 𝐷𝑌

为其相应判别器,具体生成器和判别器的结构可见下文代码。模型最终能够输出两个模型的权重,分别将两种图像的风格进行彼此迁移,生成新的图像。

该模型一个很重要的部分就是损失函数,在所有损失里面循环一致损失(Cycle Consistency Loss)是最重要的。循环损失的计算过程如下图所示(图片来源于原论文):

Cycle Consistency Loss

图中苹果图片 𝑥

经过生成器 𝐺 得到伪橘子 𝑌̂ ,然后将伪橘子 𝑌̂  结果送进生成器 𝐹 又产生苹果风格的结果 𝑥̂ ,最后将生成的苹果风格结果 𝑥̂  与原苹果图片 𝑥

一起计算出循环一致损失,反之亦然。循环损失捕捉了这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。详细的训练过程见下文代码。

数据集

本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。图像被统一缩放为256×256像素大小,其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。

这里对数据进行了随机裁剪、水平随机翻转和归一化的预处理,为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,以省略大部分数据预处理的代码。

数据集下载

使用 download 接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用 pip install download 安装 download 包。

from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"
download(url, ".", kind="zip", replace=True)

数据集加载

使用 MindSpore 的 MindDataset 接口读取和解析数据集。

from mindspore.dataset import MindDataset

​

# 读取MindRecord格式数据

name_mr = "./CycleGAN_apple2orange/apple2orange_train.mindrecord"

data = MindDataset(dataset_files=name_mr)

print("Datasize: ", data.get_dataset_size())

​

batch_size = 1

dataset = data.batch(batch_size)

datasize = dataset.get_dataset_size()
 

构建生成器

本案例生成器的模型结构参考的 ResNet 模型的结构,参考原论文,对于128×128大小的输入图片采用6个残差块相连,图片大小为256×256以上的需要采用9个残差块相连,所以本文网络有9个残差块相连,超参数 n_layers 参数控制残差块数。

生成器的结构如下所示:

CycleGAN Generator

具体的模型结构请参照下文代码:

import mindspore.nn as nn

import mindspore.ops as ops

from mindspore.common.initializer import Normal

​

weight_init = Normal(sigma=0.02)

​

class ConvNormReLU(nn.Cell):

    def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',

                 pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):

        super(ConvNormReLU, self).__init__()

        norm = nn.BatchNorm2d(out_planes)

        if norm_mode == 'instance':

            norm = nn.BatchNorm2d(out_planes, affine=False)

        has_bias = (norm_mode == 'instance')

        if padding is None:

            padding = (kernel_size - 1) // 2

        if pad_mode == 'CONSTANT':

            if transpose:

                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',

                                          has_bias=has_bias, weight_init=weight_init)

            else:

                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',

                                 has_bias=has_bias, padding=padding, weight_init=weight_init)

            layers = [conv, norm]

        else:

            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))

            pad = nn.Pad(paddings=paddings, mode=pad_mode)

            if transpose:

                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',

                                          has_bias=has_bias, weight_init=weight_init)

            else:

                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',

                                 has_bias=has_bias, weight_init=weight_init)

            layers = [pad, conv, norm]

        if use_relu:

            relu = nn.ReLU()

            if alpha > 0:

                relu = nn.LeakyReLU(alpha)

            layers.append(relu)

        self.features = nn.SequentialCell(layers)

​

    def construct(self, x):

        output = self.features(x)

        return output

​

​

class ResidualBlock(nn.Cell):

    def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode="CONSTANT"):

        super(ResidualBlock, self).__init__()

        self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)

        self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)

        self.dropout = dropout

        if dropout:

            self.dropout = nn.Dropout(p=0.5)

​

    def construct(self, x):

        out = self.conv1(x)

        if self.dropout:

            out = self.dropout(out)

        out = self.conv2(out)

        return x + out

​

​

class ResNetGenerator(nn.Cell):

    def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,

                 pad_mode="CONSTANT"):

        super(ResNetGenerator, self).__init__()

        self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)

        self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)

        self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)

        layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers

        self.residuals = nn.SequentialCell(layers)

        self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)

        self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)

        if pad_mode == "CONSTANT":

            self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',

                                      padding=3, weight_init=weight_init)

        else:

            pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)

            conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)

            self.conv_out = nn.SequentialCell([pad, conv])

​

    def construct(self, x):

        x = self.conv_in(x)

        x = self.down_1(x)

        x = self.down_2(x)

        x = self.residuals(x)

        x = self.up_2(x)

        x = self.up_1(x)

        output = self.conv_out(x)

        return ops.tanh(output)

​

# 实例化生成器

net_rg_a = ResNetGenerator()

net_rg_a.update_parameters_name('net_rg_a.')

​

net_rg_b = ResNetGenerator()

net_rg_b.update_parameters_name('net_rg_b.')

构建判别器

判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。通过一系列的 Conv2dBatchNorm2dLeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数得到最终概率。

# 定义判别器

class Discriminator(nn.Cell):

    def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):

        super(Discriminator, self).__init__()

        kernel_size = 4

        layers = [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),

                  nn.LeakyReLU(alpha)]

        nf_mult = output_channel

        for i in range(1, n_layers):

            nf_mult_prev = nf_mult

            nf_mult = min(2 ** i, 8) * output_channel

            layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))

        nf_mult_prev = nf_mult

        nf_mult = min(2 ** n_layers, 8) * output_channel

        layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))

        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))

        self.features = nn.SequentialCell(layers)

​

    def construct(self, x):

        output = self.features(x)

        return output

​

# 判别器初始化

net_d_a = Discriminator()

net_d_a.update_parameters_name('net_d_a.')

​

net_d_b = Discriminator()

net_d_b.update_parameters_name('net_d_b.')

优化器和损失函数

根据不同模型需要单独的设置优化器,这是训练过程决定的。

对生成器 𝐺及其判别器 𝐷𝑌,目标损失函数定义为:

𝐿𝐺𝐴𝑁(𝐺,𝐷𝑌,𝑋,𝑌)=𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[𝑙𝑜𝑔𝐷𝑌(𝑦)]+𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[𝑙𝑜𝑔(1−𝐷𝑌(𝐺(𝑥)))]

其中 𝐺

试图生成看起来与 𝑌 中的图像相似的图像 𝐺(𝑥) ,而 𝐷𝑌 的目标是区分翻译样本 𝐺(𝑥) 和真实样本 𝑦 ,生成器的目标是最小化这个损失函数以此来对抗判别器。即 𝑚𝑖𝑛𝐺𝑚𝑎𝑥𝐷𝑌𝐿𝐺𝐴𝑁(𝐺,𝐷𝑌,𝑋,𝑌)。单独的对抗损失不能保证所学函数可以将单个输入映射到期望的输出,为了进一步减少可能的映射函数的空间,学习到的映射函数应该是周期一致的,例如对于 𝑋 的每个图像 𝑥 ,图像转换周期应能够将 𝑥 带回原始图像,可以称之为正向循环一致性,即 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥 。对于 𝑌 ,类似的 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥。可以理解采用了一个循环一致性损失来激励这种行为。循环一致损失函数定义如下:

𝐿𝑐𝑦𝑐(𝐺,𝐹)=𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[‖𝐹(𝐺(𝑥))−𝑥‖1]+𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[‖𝐺(𝐹(𝑦))−𝑦‖1]

循环一致损失能够保证重建图像 𝐹(𝐺(𝑥))与输入图像 𝑥紧密匹配。

# 构建生成器,判别器优化器

optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)

optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

​

optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)

optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

​

# GAN网络损失函数,这里最后一层不使用sigmoid函数

loss_fn = nn.MSELoss(reduction='mean')

l1_loss = nn.L1Loss("mean")

​

def gan_loss(predict, target):

    target = ops.ones_like(predict) * target

    loss = loss_fn(predict, target)

    return loss

前向计算

搭建模型前向计算损失的过程,过程如下代码。

为了减少模型振荡[1],这里遵循 Shrivastava 等人的策略[2],使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建 image_pool 函数,保留了一个图像缓冲区,用于存储生成器生成前的50个图像。

import mindspore as ms

​

# 前向计算

​

def generator(img_a, img_b):

    fake_a = net_rg_b(img_b)

    fake_b = net_rg_a(img_a)

    rec_a = net_rg_b(fake_b)

    rec_b = net_rg_a(fake_a)

    identity_a = net_rg_b(img_a)

    identity_b = net_rg_a(img_b)

    return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b

​

lambda_a = 10.0

lambda_b = 10.0

lambda_idt = 0.5

​

def generator_forward(img_a, img_b):

    true = Tensor(True, dtype.bool_)

    fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)

    loss_g_a = gan_loss(net_d_b(fake_b), true)

    loss_g_b = gan_loss(net_d_a(fake_a), true)

    loss_c_a = l1_loss(rec_a, img_a) * lambda_a

    loss_c_b = l1_loss(rec_b, img_b) * lambda_b

    loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idt

    loss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt

    loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b

    return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b

​

def generator_forward_grad(img_a, img_b):

    _, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)

    return loss_g

​

def discriminator_forward(img_a, img_b, fake_a, fake_b):

    false = Tensor(False, dtype.bool_)

    true = Tensor(True, dtype.bool_)

    d_fake_a = net_d_a(fake_a)

    d_img_a = net_d_a(img_a)

    d_fake_b = net_d_b(fake_b)

    d_img_b = net_d_b(img_b)

    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)

    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)

    loss_d = (loss_d_a + loss_d_b) * 0.5

    return loss_d

​

def discriminator_forward_a(img_a, fake_a):

    false = Tensor(False, dtype.bool_)

    true = Tensor(True, dtype.bool_)

    d_fake_a = net_d_a(fake_a)

    d_img_a = net_d_a(img_a)

    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)

    return loss_d_a

​

def discriminator_forward_b(img_b, fake_b):

    false = Tensor(False, dtype.bool_)

    true = Tensor(True, dtype.bool_)

    d_fake_b = net_d_b(fake_b)

    d_img_b = net_d_b(img_b)

    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)

    return loss_d_b

​

# 保留了一个图像缓冲区,用来存储之前创建的50个图像

pool_size = 50

def image_pool(images):

    num_imgs = 0

    image1 = []

    if isinstance(images, Tensor):

        images = images.asnumpy()

    return_images = []

    for image in images:

        if num_imgs < pool_size:

            num_imgs = num_imgs + 1

            image1.append(image)

            return_images.append(image)

        else:

            if random.uniform(0, 1) > 0.5:

                random_id = random.randint(0, pool_size - 1)

​

                tmp = image1[random_id].copy()

                image1[random_id] = image

                return_images.append(tmp)

​

            else:

                return_images.append(image)

    output = Tensor(return_images, ms.float32)

    if output.ndim != 4:

        raise ValueError("img should be 4d, but get shape {}".format(output.shape))

    return output

计算梯度和反向传播

其中梯度计算也是分开不同的模型来进行的,详情见如下代码:

模型训练

训练分为两个主要部分:训练判别器和训练生成器,在前文的判别器损失函数中,论文采用了最小二乘损失代替负对数似然目标。

  • 训练判别器:训练判别器的目的是最大程度地提高判别图像真伪的概率。按照论文的方法需要训练判别器来最小化 𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[(𝐷(𝑦)−1)2]

  • 训练生成器:如 CycleGAN 论文所述,我们希望通过最小化 𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[(𝐷(𝐺(𝑥)−1)2]

  • 来训练生成器,以产生更好的虚假图像。

下面定义了生成器和判别器的训练过程:

import os

import time

import random

import numpy as np

from PIL import Image

from mindspore import Tensor, save_checkpoint

from mindspore import dtype

​

# 由于时间原因,epochs设置为1,可根据需求进行调整

epochs = 1

save_step_num = 80

save_checkpoint_epochs = 1

save_ckpt_dir = './train_ckpt_outputs/'

​

print('Start training!')

​

for epoch in range(epochs):

    g_loss = []

    d_loss = []

    start_time_e = time.time()

    for step, data in enumerate(dataset.create_dict_iterator()):

        start_time_s = time.time()

        img_a = data["image_A"]

        img_b = data["image_B"]

        res_g = train_step_g(img_a, img_b)

        fake_a = res_g[0]

        fake_b = res_g[1]

​

        res_d = train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))

        loss_d = float(res_d.asnumpy())

        step_time = time.time() - start_time_s

​

        res = []

        for item in res_g[2:]:

            res.append(float(item.asnumpy()))

        g_loss.append(res[0])

        d_loss.append(loss_d)

​

        if step % save_step_num == 0:

            print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "

                  f"step:[{int(step):>4d}/{int(datasize):>4d}], "

                  f"time:{step_time:>3f}s,\n"

                  f"loss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, "

                  f"loss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, "

                  f"loss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, "

                  f"loss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f}")

​

    epoch_cost = time.time() - start_time_e

    per_step_time = epoch_cost / datasize

    mean_loss_d, mean_loss_g = sum(d_loss) / datasize, sum(g_loss) / datasize

​

    print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "

          f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, "

          f"mean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f}")

​

    if epoch % save_checkpoint_epochs == 0:

        os.makedirs(save_ckpt_dir, exist_ok=True)

        save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, f"g_a_{epoch}.ckpt"))

        save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, f"g_b_{epoch}.ckpt"))

        save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, f"d_a_{epoch}.ckpt"))

        save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, f"d_b_{epoch}.ckpt"))

​

print('End of training!')
 

总结

CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络相较于GAN模型和Pix2Pix 模型的训练方式,cycleGan基于对图像对的变换识别来训练相关参数。

  • 20
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值