昇思MindSpore学习总结十八 —— CycleGAN图像风格迁移互换

1、模型介绍

1.1 模型简介

        CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络,来自论文 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks 。该模型实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。

        该模型一个重要应用领域是域迁移(Domain Adaptation),可以通俗地理解为图像风格迁移。其实在 CycleGAN 之前,就已经有了域迁移模型,比如 Pix2Pix ,但是 Pix2Pix 要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的,因此 CycleGAN 诞生了,它只需要两种域的数据,而不需要他们有严格对应关系,是一种新的无监督的图像迁移网络。

1.2 模型结构

CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成,其结构如下图所示(图片来源于原论文):

        为了方便理解,这里以苹果和橘子为例介绍。上图中 𝑋𝑋可以理解为苹果,𝑌𝑌为橘子;𝐺𝐺为将苹果生成橘子风格的生成器,𝐹 为将橘子生成的苹果风格的生成器,𝐷𝑋 和 𝐷𝑌 为其相应判别器,具体生成器和判别器的结构可见下文代码。模型最终能够输出两个模型的权重,分别将两种图像的风格进行彼此迁移,生成新的图像。

        该模型一个很重要的部分就是损失函数,在所有损失里面循环一致损失(Cycle Consistency Loss)是最重要的。循环损失的计算过程如下图所示(图片来源于原论文):

图中苹果图片 𝑥 经过生成器 𝐺𝐺 得到伪橘子 𝑌^,然后将伪橘子 𝑌^ 结果送进生成器 𝐹𝐹又产生苹果风格的结果 𝑥̂ ,最后将生成的苹果风格结果 𝑥̂ 与原苹果图片 𝑥 一起计算出循环一致损失,反之亦然。循环损失捕捉了这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。详细的训练过程见下文代码。

2、数据集

        本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。图像被统一缩放为256×256像素大小,其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。

        这里对数据进行了随机裁剪、水平随机翻转和归一化的预处理,为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,以省略大部分数据预处理的代码。

2.1 数据集下载

        使用 download 接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用 pip install download 安装 download 包。

from download import download  # 从download模块中导入download函数

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"
# 定义一个字符串变量url,存储要下载的文件的URL地址

download(url, ".", kind="zip", replace=True)
# 调用download函数,下载指定URL的文件
# 第一个参数url:文件的URL地址
# 第二个参数".":下载的文件将存储在当前目录(.)中
# 第三个参数kind="zip":指明下载的文件类型是zip压缩文件
# 第四个参数replace=True:如果目标目录中已经存在同名文件,将其替换

2.2 数据集加载

使用 MindSpore 的 MindDataset 接口读取和解析数据集。

from mindspore.dataset import MindDataset  
# 从mindspore.dataset模块中导入MindDataset类

# 读取MindRecord格式数据
name_mr = "./CycleGAN_apple2orange/apple2orange_train.mindrecord"  
# 定义一个字符串变量name_mr,存储MindRecord格式数据文件的路径

data = MindDataset(dataset_files=name_mr)  
# 创建一个MindDataset对象data,用于读取指定路径的MindRecord数据文件

print("Datasize: ", data.get_dataset_size())  
# 打印数据集的大小,即MindRecord文件中样本的数量

batch_size = 1  
# 定义批处理大小为1

dataset = data.batch(batch_size)  
# 将数据集分成批次,每批次包含batch_size个样本

datasize = dataset.get_dataset_size()  
# 获取批处理后的数据集大小(批次数量)

 2.3 可视化

        通过 create_dict_iterator 函数将数据转换成字典迭代器,然后使用 matplotlib 模块可视化部分训练数据。

import numpy as np  
# 导入NumPy库,用于数组操作

import matplotlib.pyplot as plt  
# 导入Matplotlib库中的pyplot模块,用于绘图

mean = 0.5 * 255  
# 定义变量mean,表示图像标准化后的均值,取值为127.5

std = 0.5 * 255  
# 定义变量std,表示图像标准化后的标准差,取值为127.5

plt.figure(figsize=(12, 5), dpi=60)  
# 创建一个新的绘图窗口,设置图像大小为12x5英寸,分辨率为60DPI

for i, data in enumerate(dataset.create_dict_iterator()):  
    # 使用enumerate遍历数据集dataset,并生成一个字典迭代器,每个元素包含图像数据

    if i < 5:  
        # 只处理前5个样本

        show_images_a = data["image_A"].asnumpy()  
        # 获取字典中键为"image_A"的值,并转换为NumPy数组

        show_images_b = data["image_B"].asnumpy()  
        # 获取字典中键为"image_B"的值,并转换为NumPy数组

        plt.subplot(2, 5, i+1)  
        # 在2行5列的子图网格中选择第i+1个子图

        show_images_a = (show_images_a[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))  
        # 对图像进行反标准化处理,并转换数据类型为uint8(无符号8位整数),然后调整图像维度顺序为(height, width, channels)

        plt.imshow(show_images_a)  
        # 在子图中显示图像A

        plt.axis("off")  
        # 关闭子图的坐标轴

        plt.subplot(2, 5, i+6)  
        # 在2行5列的子图网格中选择第i+6个子图

        show_images_b = (show_images_b[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))  
        # 对图像进行反标准化处理,并转换数据类型为uint8(无符号8位整数),然后调整图像维度顺序为(height, width, channels)

        plt.imshow(show_images_b)  
        # 在子图中显示图像B

        plt.axis("off")  
        # 关闭子图的坐标轴

    else:  
        # 如果处理的样本数量达到5个,跳出循环
        break  

plt.show()  
# 显示所有子图

 3、构建生成器

        本案例生成器的模型结构参考的 ResNet 模型的结构,参考原论文,对于128×128大小的输入图片采用6个残差块相连,图片大小为256×256以上的需要采用9个残差块相连,所以本文网络有9个残差块相连,超参数 n_layers 参数控制残差块数。

生成器的结构如下所示:

 具体的模型结构请参照下文代码:

import mindspore.nn as nn  
# 导入MindSpore的神经网络模块

import mindspore.ops as ops  
# 导入MindSpore的操作模块

from mindspore.common.initializer import Normal  
# 从MindSpore的common.initializer模块中导入Normal初始化器

weight_init = Normal(sigma=0.02)  
# 定义一个权重初始化器,标准差为0.02的正态分布

class ConvNormReLU(nn.Cell):  
    # 定义一个卷积、归一化和ReLU激活函数的组合层
    def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',
                 pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):
        super(ConvNormReLU, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if padding is None:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, padding=padding, weight_init=weight_init)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, weight_init=weight_init)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

# ConvNormReLU类的功能是构建一个包含卷积、归一化(批量或实例归一化)和ReLU激活函数的模块。根据传入的参数,可以配置不同的卷积模式(普通或转置卷积)、归一化模式(批量或实例归一化)以及是否使用ReLU激活函数。

class ResidualBlock(nn.Cell):  
    # 定义一个残差块
    def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode="CONSTANT"):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)
        self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)
        self.dropout = dropout
        if dropout:
            self.dropout = nn.Dropout(p=0.5)

    def construct(self, x):
        out = self.conv1(x)
        if self.dropout:
            out = self.dropout(out)
        out = self.conv2(out)
        return x + out

# ResidualBlock类的功能是构建一个残差块,该块包含两个ConvNormReLU模块(第二个模块不使用ReLU激活)。如果启用了dropout,还会在第一个卷积后应用dropout。最终输出是输入与第二个卷积结果的和。

class ResNetGenerator(nn.Cell):  
    # 定义一个ResNet生成器
    def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,
                 pad_mode="CONSTANT"):
        super(ResNetGenerator, self).__init__()
        self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)
        self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)
        self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)
        layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers
        self.residuals = nn.SequentialCell(layers)
        self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)
        self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)
        if pad_mode == "CONSTANT":
            self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',
                                      padding=3, weight_init=weight_init)
        else:
            pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)
            conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)
            self.conv_out = nn.SequentialCell([pad, conv])

    def construct(self, x):
        x = self.conv_in(x)
        x = self.down_1(x)
        x = self.down_2(x)
        x = self.residuals(x)
        x = self.up_2(x)
        x = self.up_1(x)
        output = self.conv_out(x)
        return ops.tanh(output)

# ResNetGenerator类的功能是构建一个ResNet生成器模型。该模型包括输入卷积层、下采样层、多个残差块、上采样层和输出卷积层。输入的图像经过这些层的处理,生成一个新的图像。最终输出通过tanh函数进行非线性变换,以限制其值域在[-1, 1]之间。

# 实例化生成器
net_rg_a = ResNetGenerator()  
# 实例化ResNet生成器模型net_rg_a

net_rg_a.update_parameters_name('net_rg_a.')  
# 更新模型net_rg_a的参数名称前缀为'net_rg_a.'

net_rg_b = ResNetGenerator()  
# 实例化另一个ResNet生成器模型net_rg_b

net_rg_b.update_parameters_name('net_rg_b.')  
# 更新模型net_rg_b的参数名称前缀为'net_rg_b.'

 4、构建判别器

        判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。通过一系列的 Conv2d 、 BatchNorm2d 和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数得到最终概率。

class Discriminator(nn.Cell):  
    # 定义一个判别器类
    def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init), 
            nn.LeakyReLU(alpha)
        ]
        # 初始化判别器的第一层卷积层,并设置LeakyReLU激活函数

        nf_mult = output_channel
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * output_channel
            layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        # 逐层添加卷积层、归一化和激活函数模块,循环添加n_layers层

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * output_channel
        layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        # 添加最后一层卷积层和归一化、激活函数模块

        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))
        # 添加输出卷积层,将输出通道数设置为1

        self.features = nn.SequentialCell(layers)
        # 将所有层组合成一个顺序执行的模块

    def construct(self, x):
        output = self.features(x)
        return output
        # 定义前向传播函数,输入x经过判别器网络后输出结果

# 判别器初始化
net_d_a = Discriminator()  
# 实例化判别器模型net_d_a

net_d_a.update_parameters_name('net_d_a.')  
# 更新模型net_d_a的参数名称前缀为'net_d_a.'

net_d_b = Discriminator()  
# 实例化另一个判别器模型net_d_b

net_d_b.update_parameters_name('net_d_b.')  
# 更新模型net_d_b的参数名称前缀为'net_d_b.'

5、优化器和损失函数

根据不同模型需要单独的设置优化器,这是训练过程决定的。

对生成器 𝐺及其判别器 𝐷𝑌 ,目标损失函数定义为:

 其中 𝐺 试图生成看起来与 𝑌 中的图像相似的图像 𝐺(𝑥) ,而 𝐷𝑌 的目标是区分翻译样本 𝐺(𝑥) 和真实样本 𝑦 ,生成器的目标是最小化这个损失函数以此来对抗判别器。即 𝑚𝑖𝑛𝐺𝑚𝑎𝑥𝐷𝑌𝐿𝐺𝐴𝑁(𝐺,𝐷𝑌,𝑋,𝑌)。

单独的对抗损失不能保证所学函数可以将单个输入映射到期望的输出,为了进一步减少可能的映射函数的空间,学习到的映射函数应该是周期一致的,例如对于 𝑋 的每个图像 𝑥 ,图像转换周期应能够将 𝑥 带回原始图像,可以称之为正向循环一致性,即 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥 。对于 𝑌 ,类似的 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥 。可以理解采用了一个循环一致性损失来激励这种行为。

循环一致损失函数定义如下:

 循环一致损失能够保证重建图像 𝐹(𝐺(𝑥))𝐹(𝐺(𝑥)) 与输入图像 𝑥𝑥 紧密匹配。

# 构建生成器和判别器的优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
# 为生成器net_rg_a创建Adam优化器,学习率为0.0002,beta1为0.5

optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)
# 为生成器net_rg_b创建Adam优化器,学习率为0.0002,beta1为0.5

optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
# 为判别器net_d_a创建Adam优化器,学习率为0.0002,beta1为0.5

optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)
# 为判别器net_d_b创建Adam优化器,学习率为0.0002,beta1为0.5

# GAN网络损失函数,这里最后一层不使用sigmoid函数
loss_fn = nn.MSELoss(reduction='mean')
# 定义均方误差(MSE)损失函数,使用均值作为损失的聚合方式

l1_loss = nn.L1Loss("mean")
# 定义L1损失函数,使用均值作为损失的聚合方式

def gan_loss(predict, target):
    target = ops.ones_like(predict) * target
    # 根据预测值的形状创建一个与target值相同的张量
    loss = loss_fn(predict, target)
    # 计算预测值与目标值之间的MSE损失
    return loss
    # 返回损失值

# 该代码块的主要功能是为生成器和判别器定义优化器以及损失函数。Adam优化器用于更新模型参数,MSE损失函数用于计算生成对抗网络中的损失值。

6、前向计算

搭建模型前向计算损失的过程,过程如下代码。

        为了减少模型振荡[1],这里遵循 Shrivastava 等人的策略[2],使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建 image_pool 函数,保留了一个图像缓冲区,用于存储生成器生成前的50个图像。

import mindspore as ms  
# 导入MindSpore库

# 定义前向计算函数
def generator(img_a, img_b):  
    # 定义生成器的前向传播函数
    fake_a = net_rg_b(img_b)  
    # 用生成器net_rg_b将img_b转换为fake_a

    fake_b = net_rg_a(img_a)  
    # 用生成器net_rg_a将img_a转换为fake_b

    rec_a = net_rg_b(fake_b)  
    # 用生成器net_rg_b将fake_b转换回rec_a(重建的img_a)

    rec_b = net_rg_a(fake_a)  
    # 用生成器net_rg_a将fake_a转换回rec_b(重建的img_b)

    identity_a = net_rg_b(img_a)  
    # 用生成器net_rg_b生成img_a的身份映射identity_a

    identity_b = net_rg_a(img_b)  
    # 用生成器net_rg_a生成img_b的身份映射identity_b

    return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b  
    # 返回生成的假图像、重建图像和身份映射

# 定义损失函数的权重
lambda_a = 10.0  
# 循环一致性损失的权重lambda_a

lambda_b = 10.0  
# 循环一致性损失的权重lambda_b

lambda_idt = 0.5  
# 身份映射损失的权重lambda_idt

def generator_forward(img_a, img_b):  
    # 定义生成器的前向传播计算
    true = Tensor(True, dtype=ms.bool_)  
    # 定义一个值为True的张量,用于表示真实标签

    fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)  
    # 获取生成的假图像、重建图像和身份映射

    loss_g_a = gan_loss(net_d_b(fake_b), true)  
    # 计算生成的fake_b被判别器net_d_b认为是真的损失

    loss_g_b = gan_loss(net_d_a(fake_a), true)  
    # 计算生成的fake_a被判别器net_d_a认为是真的损失

    loss_c_a = l1_loss(rec_a, img_a) * lambda_a  
    # 计算img_a和rec_a之间的L1损失,并乘以权重lambda_a

    loss_c_b = l1_loss(rec_b, img_b) * lambda_b  
    # 计算img_b和rec_b之间的L1损失,并乘以权重lambda_b

    loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idt  
    # 计算img_a和identity_a之间的L1损失,并乘以权重lambda_a和lambda_idt

    loss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt  
    # 计算img_b和identity_b之间的L1损失,并乘以权重lambda_b和lambda_idt

    loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b  
    # 总的生成器损失

    return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b  
    # 返回生成的假图像和各种损失

def generator_forward_grad(img_a, img_b):  
    # 定义生成器的前向传播计算,用于求解梯度
    _, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)  
    # 计算生成器的前向传播,获取总损失loss_g

    return loss_g  
    # 返回总损失

def discriminator_forward(img_a, img_b, fake_a, fake_b):  
    # 定义判别器的前向传播计算
    false = Tensor(False, dtype=ms.bool_)  
    # 定义一个值为False的张量,用于表示假标签

    true = Tensor(True, dtype=ms.bool_)  
    # 定义一个值为True的张量,用于表示真实标签

    d_fake_a = net_d_a(fake_a)  
    # 判别器net_d_a对假图像fake_a的预测

    d_img_a = net_d_a(img_a)  
    # 判别器net_d_a对真实图像img_a的预测

    d_fake_b = net_d_b(fake_b)  
    # 判别器net_d_b对假图像fake_b的预测

    d_img_b = net_d_b(img_b)  
    # 判别器net_d_b对真实图像img_b的预测

    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)  
    # 判别器net_d_a的损失,包括对假图像预测为假的损失和对真实图像预测为真的损失

    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)  
    # 判别器net_d_b的损失,包括对假图像预测为假的损失和对真实图像预测为真的损失

    loss_d = (loss_d_a + loss_d_b) * 0.5  
    # 总的判别器损失,取两部分损失的平均值

    return loss_d  
    # 返回判别器损失

def discriminator_forward_a(img_a, fake_a):  
    # 定义判别器net_d_a的前向传播计算
    false = Tensor(False, dtype=ms.bool_)  
    # 定义一个值为False的张量,用于表示假标签

    true = Tensor(True, dtype=ms.bool_)  
    # 定义一个值为True的张量,用于表示真实标签

    d_fake_a = net_d_a(fake_a)  
    # 判别器net_d_a对假图像fake_a的预测

    d_img_a = net_d_a(img_a)  
    # 判别器net_d_a对真实图像img_a的预测

    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)  
    # 判别器net_d_a的损失,包括对假图像预测为假的损失和对真实图像预测为真的损失

    return loss_d_a  
    # 返回判别器net_d_a的损失

def discriminator_forward_b(img_b, fake_b):  
    # 定义判别器net_d_b的前向传播计算
    false = Tensor(False, dtype=ms.bool_)  
    # 定义一个值为False的张量,用于表示假标签

    true = Tensor(True, dtype=ms.bool_)  
    # 定义一个值为True的张量,用于表示真实标签

    d_fake_b = net_d_b(fake_b)  
    # 判别器net_d_b对假图像fake_b的预测

    d_img_b = net_d_b(img_b)  
    # 判别器net_d_b对真实图像img_b的预测

    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)  
    # 判别器net_d_b的损失,包括对假图像预测为假的损失和对真实图像预测为真的损失

    return loss_d_b  
    # 返回判别器net_d_b的损失

# 保留了一个图像缓冲区,用来存储之前创建的50个图像
pool_size = 50  
# 定义图像池的大小

def image_pool(images):  
    # 定义图像池函数,用于存储之前创建的图像
    num_imgs = 0  
    # 记录当前图像池中的图像数量

    image1 = []  
    # 初始化图像池

    if isinstance(images, Tensor):
        images = images.asnumpy()  
    # 如果输入是Tensor类型,将其转换为NumPy数组

    return_images = []  
    # 初始化返回的图像列表

    for image in images:  
        # 遍历输入的图像
        if num_imgs < pool_size:  
            # 如果图像池未满
            num_imgs = num_imgs + 1  
            # 增加图像计数

            image1.append(image)  
            # 将图像添加到图像池中

            return_images.append(image)  
            # 将图像添加到返回列表中

        else:  
            # 如果图像池已满
            if random.uniform(0, 1) > 0.5:  
                # 以50%的概率选择随机替换图像
                random_id = random.randint(0, pool_size - 1)  
                # 生成一个随机索引

                tmp = image1[random_id].copy()  
                # 复制随机选择的图像

                image1[random_id] = image  
                # 用新图像替换池中的图像

                return_images.append(tmp)  
                # 将被替换的图像添加到返回列表中

            else:  
                # 否则
                return

7、计算梯度和反向传播

其中梯度计算也是分开不同的模型来进行的,详情见如下代码:

from mindspore import value_and_grad  
# 从MindSpore导入value_and_grad函数,用于计算值和梯度

# 实例化求梯度的方法
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())  
# 为生成器net_rg_a创建计算梯度的方法

grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())  
# 为生成器net_rg_b创建计算梯度的方法

grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())  
# 为判别器net_d_a创建计算梯度的方法

grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())  
# 为判别器net_d_b创建计算梯度的方法

# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):  
    # 定义生成器的训练步骤
    net_d_a.set_grad(False)  
    # 禁止判别器net_d_a的梯度计算

    net_d_b.set_grad(False)  
    # 禁止判别器net_d_b的梯度计算

    fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)  
    # 前向计算生成器,获取假图像和各种损失

    _, grads_g_a = grad_g_a(img_a, img_b)  
    # 计算生成器net_rg_a的梯度

    _, grads_g_b = grad_g_b(img_a, img_b)  
    # 计算生成器net_rg_b的梯度

    optimizer_rg_a(grads_g_a)  
    # 更新生成器net_rg_a的参数

    optimizer_rg_b(grads_g_b)  
    # 更新生成器net_rg_b的参数

    return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib  
    # 返回生成的假图像和各种损失

# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):  
    # 定义判别器的训练步骤
    net_d_a.set_grad(True)  
    # 允许判别器net_d_a的梯度计算

    net_d_b.set_grad(True)  
    # 允许判别器net_d_b的梯度计算

    loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)  
    # 计算判别器net_d_a的损失和梯度

    loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)  
    # 计算判别器net_d_b的损失和梯度

    loss_d = (loss_d_a + loss_d_b) * 0.5  
    # 总的判别器损失,取两部分损失的平均值

    optimizer_d_a(grads_d_a)  
    # 更新判别器net_d_a的参数

    optimizer_d_b(grads_d_b)  
    # 更新判别器net_d_b的参数

    return loss_d  
    # 返回判别器损失

8、模型训练

        训练分为两个主要部分:训练判别器和训练生成器,在前文的判别器损失函数中,论文采用了最小二乘损失代替负对数似然目标。

  • 训练判别器:训练判别器的目的是最大程度地提高判别图像真伪的概率。按照论文的方法需要训练判别器来最小化 𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[(𝐷(𝑦)−1)^2];

  • 训练生成器:如 CycleGAN 论文所述,我们希望通过最小化 𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[(𝐷(𝐺(𝑥)−1)^2]来训练生成器,以产生更好的虚假图像。

下面定义了生成器和判别器的训练过程:

import os  
# 导入os模块,用于创建目录和处理文件路径

import time  
# 导入time模块,用于计时

import random  
# 导入random模块,用于随机数生成

import numpy as np  
# 导入NumPy库,用于数组操作

from PIL import Image  
# 导入PIL库中的Image模块,用于图像处理

from mindspore import Tensor, save_checkpoint  
# 从MindSpore库中导入Tensor类和save_checkpoint函数

from mindspore import dtype  
# 从MindSpore库中导入dtype模块

# 由于时间原因,epochs设置为1,可根据需求进行调整
epochs = 1  
# 定义训练的轮数

save_step_num = 80  
# 每隔80步打印和保存一次训练信息

save_checkpoint_epochs = 1  
# 每隔1个epoch保存一次模型

save_ckpt_dir = './train_ckpt_outputs/'  
# 定义模型检查点保存的目录

print('Start training!')  
# 打印训练开始的消息

for epoch in range(epochs):  
    # 迭代训练的轮数
    g_loss = []  
    # 初始化生成器损失列表

    d_loss = []  
    # 初始化判别器损失列表

    start_time_e = time.time()  
    # 记录epoch开始的时间

    for step, data in enumerate(dataset.create_dict_iterator()):  
        # 迭代数据集的每个批次
        start_time_s = time.time()  
        # 记录步骤开始的时间

        img_a = data["image_A"]  
        # 获取批次中的图像A

        img_b = data["image_B"]  
        # 获取批次中的图像B

        res_g = train_step_g(img_a, img_b)  
        # 执行生成器的训练步骤,返回假图像和各种损失

        fake_a = res_g[0]  
        # 获取生成的假图像A

        fake_b = res_g[1]  
        # 获取生成的假图像B

        res_d = train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))  
        # 执行判别器的训练步骤,返回判别器损失

        loss_d = float(res_d.asnumpy())  
        # 将判别器损失转换为浮点数

        step_time = time.time() - start_time_s  
        # 计算每个步骤的时间

        res = []  
        # 初始化结果列表

        for item in res_g[2:]:  
            # 迭代生成器的损失结果
            res.append(float(item.asnumpy()))  
            # 将损失值转换为浮点数并添加到结果列表中

        g_loss.append(res[0])  
        # 将生成器损失添加到生成器损失列表中

        d_loss.append(loss_d)  
        # 将判别器损失添加到判别器损失列表中

        if step % save_step_num == 0:  
            # 如果步数是save_step_num的倍数
            print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "
                  f"step:[{int(step):>4d}/{int(datasize):>4d}], "
                  f"time:{step_time:>3f}s,\n"
                  f"loss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, "
                  f"loss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, "
                  f"loss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, "
                  f"loss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f}")
            # 打印当前epoch和步骤的信息,包括时间和各种损失

    epoch_cost = time.time() - start_time_e  
    # 计算每个epoch的时间

    per_step_time = epoch_cost / datasize  
    # 计算每个步骤的平均时间

    mean_loss_d, mean_loss_g = sum(d_loss) / datasize, sum(g_loss) / datasize  
    # 计算生成器和判别器的平均损失

    print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "
          f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, "
          f"mean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f}")
    # 打印每个epoch的信息,包括时间和平均损失

    if epoch % save_checkpoint_epochs == 0:  
        # 如果epoch是save_checkpoint_epochs的倍数
        os.makedirs(save_ckpt_dir, exist_ok=True)  
        # 创建保存检查点的目录

        save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, f"g_a_{epoch}.ckpt"))  
        # 保存生成器net_rg_a的检查点

        save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, f"g_b_{epoch}.ckpt"))  
        # 保存生成器net_rg_b的检查点

        save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, f"d_a_{epoch}.ckpt"))  
        # 保存判别器net_d_a的检查点

        save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, f"d_b_{epoch}.ckpt"))  
        # 保存判别器net_d_b的检查点

print('End of training!')  
# 打印训练结束的消息
  1. 导入必要的模块

    • 导入了多个模块和库,用于创建目录、处理图像、计时和计算等功能。
  2. 训练参数设置

    • epochs:定义训练的轮数,设置为1以节省时间。
    • save_step_num:定义每隔多少步保存一次模型和打印信息。
    • save_checkpoint_epochs:定义每隔多少个epoch保存一次检查点。
    • save_ckpt_dir:定义模型检查点保存的目录。
  3. 训练循环

    • for epoch in range(epochs):迭代每个epoch。
    • 在每个epoch中,初始化损失列表并记录开始时间。
    • for step, data in enumerate(dataset.create_dict_iterator()):迭代数据集的每个批次。
    • 获取批次中的图像A和图像B。
    • 执行生成器和判别器的训练步骤,计算相应的损失。
    • 每隔save_step_num步打印训练信息。
    • 计算并打印每个epoch的时间和平均损失。
    • 每隔save_checkpoint_epochs个epoch保存模型检查点。
  4. 保存模型检查点

    • 使用save_checkpoint函数保存生成器和判别器的检查点。
  5. 训练结束消息

    • 打印训练结束的消息。

9、模型推理

        下面我们通过加载生成器网络模型参数文件来对原图进行风格迁移,结果中第一行为原图,第二行为对应生成的结果图。

%%time
# 记录代码运行时间

import os  
# 导入os模块,用于文件和目录操作

from PIL import Image  
# 导入PIL库中的Image模块,用于图像处理

import mindspore.dataset as ds  
# 导入MindSpore的dataset模块,用于创建和处理数据集

import mindspore.dataset.vision as vision  
# 导入MindSpore的vision模块,用于图像预处理操作

from mindspore import load_checkpoint, load_param_into_net  
# 从MindSpore库中导入load_checkpoint和load_param_into_net函数,用于加载模型权重

# 加载权重文件
def load_ckpt(net, ckpt_dir):  
    # 定义加载权重文件的函数
    param_GA = load_checkpoint(ckpt_dir)  
    # 从指定目录加载检查点文件

    load_param_into_net(net, param_GA)  
    # 将加载的参数加载到网络中

g_a_ckpt = './CycleGAN_apple2orange/ckpt/g_a.ckpt'  
# 定义生成器A的检查点文件路径

g_b_ckpt = './CycleGAN_apple2orange/ckpt/g_b.ckpt'  
# 定义生成器B的检查点文件路径

load_ckpt(net_rg_a, g_a_ckpt)  
# 加载生成器A的权重

load_ckpt(net_rg_b, g_b_ckpt)  
# 加载生成器B的权重

# 图片推理
fig = plt.figure(figsize=(11, 2.5), dpi=100)  
# 创建一个新的图形窗口,设置图像大小和分辨率

def eval_data(dir_path, net, a):  
    # 定义图像推理函数
    def read_img():  
        # 定义读取图像的生成器函数
        for dir in os.listdir(dir_path):  
            # 遍历目录中的所有文件
            path = os.path.join(dir_path, dir)  
            # 构建文件路径

            img = Image.open(path).convert('RGB')  
            # 打开图像并转换为RGB模式

            yield img, dir  
            # 生成图像和图像名称

    dataset = ds.GeneratorDataset(read_img, column_names=["image", "image_name"])  
    # 创建一个GeneratorDataset,使用read_img生成器函数作为数据源

    trans = [vision.Resize((256, 256)), vision.Normalize(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3), vision.HWC2CHW()]  
    # 定义图像预处理操作,包括调整大小、归一化和维度转换

    dataset = dataset.map(operations=trans, input_columns=["image"])  
    # 将预处理操作应用到数据集

    dataset = dataset.batch(1)  
    # 将数据集分批次处理,每批次包含1个样本

    for i, data in enumerate(dataset.create_dict_iterator()):  
        # 迭代数据集的每个批次
        img = data["image"]  
        # 获取图像数据

        fake = net(img)  
        # 使用网络生成假图像

        fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))  
        # 对生成的假图像进行反归一化和维度转换

        img = (img[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))  
        # 对原始图像进行反归一化和维度转换

        fig.add_subplot(2, 8, i+1+a)  
        # 在图形窗口中添加子图

        plt.axis("off")  
        # 关闭坐标轴

        plt.imshow(img.asnumpy())  
        # 显示原始图像

        fig.add_subplot(2, 8, i+9+a)  
        # 在图形窗口中添加子图

        plt.axis("off")  
        # 关闭坐标轴

        plt.imshow(fake.asnumpy())  
        # 显示生成的假图像

eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0)  
# 对apple目录中的图像进行推理,并使用生成器A

eval_data('./CycleGAN_apple2orange/predict/orange', net_rg_b, 4)  
# 对orange目录中的图像进行推理,并使用生成器B

plt.show()  
# 显示所有图像

打卡

 

  • 8
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值