昇思 25 天学习打卡营第 19 天 | CycleGAN图像风格迁移互换

今天是参加昇思学习打卡营的第19天,学习内容是CycleGAN图像风格迁移互换。

以下是关键点概要:

  • 模型介绍:CycleGAN由两篇论文提出,主要用于在没有成对样本的情况下实现图像风格迁移。它在图像风格迁移和域迁移领域有广泛应用。

  • 模型结构:CycleGAN由两个对称的生成对抗网络(GAN)组成,每个网络包含一个生成器和一个判别器。生成器负责将图像从一个风格转换到另一个风格,而判别器则区分真实图像和生成器生成的图像。

  • 损失函数:CycleGAN的关键之一是循环一致性损失(Cycle Consistency Loss),它确保了图像在经过风格转换来回后能够尽可能地回到原始状态。

  • 数据集:教程中使用的是ImageNet数据集中的苹果和橘子图片,图片被预处理并转换为MindRecord格式。

  • 数据集加载与预处理:使用MindSpore的MindDataset接口来读取数据,并进行了随机裁剪、翻转和归一化处理。

  • 构建生成器和判别器:生成器的结构参考了ResNet模型,而判别器则是一个二分类网络,使用了PatchGANs模型。

  • 优化器和损失函数:为生成器和判别器设置了不同的优化器,并定义了对抗损失和循环一致性损失。

  • 前向计算:实现了生成器和判别器的前向计算过程,包括生成假图像、计算损失等。

  • 模型训练:详细介绍了训练循环,包括生成器和判别器的梯度计算和参数更新。

  • 模型推理:展示了如何加载训练好的模型参数并进行图像风格迁移的推理过程。

模型介绍

模型简介

CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络,来自论文 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks 。该模型实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。

该模型一个重要应用领域是域迁移(Domain Adaptation),可以通俗地理解为图像风格迁移。其实在 CycleGAN 之前,就已经有了域迁移模型,比如 Pix2Pix ,但是 Pix2Pix 要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的,因此 CycleGAN 诞生了,它只需要两种域的数据,而不需要他们有严格对应关系,是一种新的无监督的图像迁移网络。

模型结构

CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成,其结构如下图所示(图片来源于原论文):

CycleGAN

为了方便理解,这里以苹果和橘子为例介绍。上图中 𝑋𝑋 可以理解为苹果,𝑌𝑌 为橘子;𝐺𝐺 为将苹果生成橘子风格的生成器,𝐹𝐹 为将橘子生成的苹果风格的生成器,𝐷𝑋𝐷𝑋 和 𝐷𝑌𝐷𝑌 为其相应判别器,具体生成器和判别器的结构可见下文代码。模型最终能够输出两个模型的权重,分别将两种图像的风格进行彼此迁移,生成新的图像。

该模型一个很重要的部分就是损失函数,在所有损失里面循环一致损失(Cycle Consistency Loss)是最重要的。循环损失的计算过程如下图所示(图片来源于原论文):

Cycle Consistency Loss

图中苹果图片 𝑥𝑥 经过生成器 𝐺𝐺 得到伪橘子 𝑌̂ 𝑌^,然后将伪橘子 𝑌̂ 𝑌^ 结果送进生成器 𝐹𝐹 又产生苹果风格的结果 𝑥̂ 𝑥^,最后将生成的苹果风格结果 𝑥̂ 𝑥^ 与原苹果图片 𝑥𝑥 一起计算出循环一致损失,反之亦然。循环损失捕捉了这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。详细的训练过程见下文代码。

数据集

本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。图像被统一缩放为256×256像素大小,其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。

这里对数据进行了随机裁剪、水平随机翻转和归一化的预处理,为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,以省略大部分数据预处理的代码。

数据集下载

使用 download 接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用 pip install download 安装 download 包。

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"

download(url, ".", kind="zip", replace=True)

# 从download模块导入download函数
# 这个download函数可能是一个自定义的函数,用于下载文件
from download import download

# 定义一个字符串变量url,存储要下载的文件的URL地址
# 这里的URL指向一个名为CycleGAN_apple2orange.zip的压缩文件
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"

# 调用download函数下载文件
# 下载的文件将保存在当前目录(由'.'表示)
# kind参数指定了文件类型为zip格式的压缩包
# replace参数设置为True,表示如果已存在同名文件,将被新下载的文件替换
download(url, ".", kind="zip", replace=True)

 

数据集加载

使用 MindSpore 的 MindDataset 接口读取和解析数据集。

# 从mindspore.dataset包中导入MindDataset类
from mindspore.dataset import MindDataset

# 定义MindRecord文件的路径
# 这里指定了训练数据集的MindRecord文件位置
name_mr = "./CycleGAN_apple2orange/apple2orange_train.mindrecord"

# 使用MindDataset类创建数据集对象
# 通过传递MindRecord文件的路径给dataset_files参数来初始化数据集
data = MindDataset(dataset_files=name_mr)

# 打印数据集的大小,即数据集中样本的总数
# 使用get_dataset_size()方法获取数据集的大小
print("Datasize: ", data.get_dataset_size())

# 设置批处理大小为1
# 这意味着每个批次将包含一个样本
batch_size = 1

# 对数据集应用批处理操作,创建一个新的数据集对象
# 这个对象将按照指定的批处理大小来组织数据
dataset = data.batch(batch_size)

# 获取并打印批处理后数据集的大小
# 这应该与原始数据集的大小相同,因为批处理不会改变数据的总数
datasize = dataset.get_dataset_size()

 

可视化

通过 create_dict_iterator 函数将数据转换成字典迭代器,然后使用 matplotlib 模块可视化部分训练数据。

# 导入numpy库用于数学运算
import numpy as np
# 导入matplotlib.pyplot用于绘图
import matplotlib.pyplot as plt

# 定义均值和标准差,用于数据的标准化处理
# 这里的255是假定数据是8位整数格式,即像素值范围在0-255
mean = 0.5 * 255
std = 0.5 * 255

# 创建一个新的matplotlib图形,设置图形的大小和分辨率
plt.figure(figsize=(12, 5), dpi=60)

# 使用enumerate遍历数据集中的字典形式的迭代器
for i, data in enumerate(dataset.create_dict_iterator()):
    # 只处理前5个数据点
    if i < 5:
        # 从数据集中提取名为"image_A"和"image_B"的图像数据,并转换为numpy数组
        show_images_a = data["image_A"].asnumpy()
        show_images_b = data["image_B"].asnumpy()

        # 计算subplot的位置,2行5列的子图布局,i+1是子图的索引
        plt.subplot(2, 5, i+1)
        # 将图像数据标准化到[0, 255]范围并转换为uint8类型,然后转换通道顺序
        show_images_a = (show_images_a[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
        # 显示图像,关闭坐标轴
        plt.imshow(show_images_a)
        plt.axis("off")

        # 计算subplot的位置,2行5列的子图布局,i+6是子图的索引
        plt.subplot(2, 5, i+6)
        # 同上,对第二组图像进行相同的处理
        show_images_b = (show_images_b[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
        plt.imshow(show_images_b)
        plt.axis("off")
    # 如果已经处理了5个数据点,则退出循环
    else:
        break

# 显示所有子图组成的图形
plt.show()

构建生成器

本案例生成器的模型结构参考的 ResNet 模型的结构,参考原论文,对于128×128大小的输入图片采用6个残差块相连,图片大小为256×256以上的需要采用9个残差块相连,所以本文网络有9个残差块相连,超参数 n_layers 参数控制残差块数。

生成器的结构如下所示:

CycleGAN Generator

# 从mindspore.nn包中导入nn模块
import mindspore.nn as nn
# 从mindspore.ops包中导入ops模块
import mindspore.ops as ops
# 从mindspore.common.initializer包中导入initializer模块
from mindspore.common.initializer import Normal

# 使用Normal初始化器初始化权重,标准差为0.02
weight_init = Normal(sigma=0.02)

# 定义一个卷积-归一化-ReLU激活函数的组合层
class ConvNormReLU(nn.Cell):
    def __init__(self, ...):
        # 初始化层的各种参数
        ...
        
        # 根据参数构建层的组件
        ...
        
    # construct方法定义了如何前向传播数据
    def construct(self, x):
        # 使用SequentialCell来顺序执行层中的操作
        output = self.features(x)
        return output

# 定义残差块
class ResidualBlock(nn.Cell):
    def __init__(self, ...):
        # 初始化残差块的参数
        ...
        
        # 构建残差块的卷积层
        ...
        
    # construct方法定义了残差块的前向传播
    def construct(self, x):
        # 执行卷积操作并加上输入x
        output = self.conv1(x)
        ...
        
        return x + out

# 定义基于ResNet的生成器网络
class ResNetGenerator(nn.Cell):
    def __init__(self, ...):
        # 初始化生成器网络的参数
        ...
        
        # 构建生成器网络的各层
        ...
        
    # construct方法定义了生成器网络的前向传播
    def construct(self, x):
        # 执行网络中的卷积、激活、上采样等操作
        ...
        
        output = self.conv_out(x)
        return ops.tanh(output)  # 使用tanh激活函数输出最终结果

# 实例化生成器网络
net_rg_a = ResNetGenerator()  # 创建生成器A
net_rg_a.update_parameters_name('net_rg_a.')  # 更新参数名称前缀

net_rg_b = ResNetGenerator()  # 创建生成器B
net_rg_b.update_parameters_name('net_rg_b.')  # 更新参数名称前缀

 

构建判别器

判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。通过一系列的 Conv2d 、 BatchNorm2d 和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数得到最终概率。

# 定义判别器类,继承自nn.Cell
class Discriminator(nn.Cell):
    # 初始化方法,设置判别器的层和参数
    def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):
        super(Discriminator, self).__init__()  # 调用基类的初始化方法
        kernel_size = 4  # 设置卷积核大小
        layers = []  # 初始化一个列表,用于存储判别器的层
        # 添加第一个卷积层和LeakyReLU激活函数
        layers.append(
            nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),
            nn.LeakyReLU(alpha)
        )
        nf_mult = output_channel  # 设置当前的特征图的通道数
        # 循环添加卷积层和归一化层
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * output_channel  # 计算下一层的通道数
            layers.append(
                ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1)
            )
        # 添加最后一个卷积层,输出尺寸为1x1的特征图
        layers.append(
            ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1),
            nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init)
        )
        self.features = nn.SequentialCell(layers)  # 将所有层包装成SequentialCell

    # construct方法定义了判别器的前向传播过程
    def construct(self, x):
        output = self.features(x)  # 通过SequentialCell执行层的顺序操作
        return output

# 实例化判别器A
net_d_a = Discriminator()
net_d_a.update_parameters_name('net_d_a.')  # 设置参数名前缀,用于区分不同网络的参数

# 实例化判别器B
net_d_b = Discriminator()
net_d_b.update_parameters_name('net_d_b.')  # 设置参数名前缀

优化器和损失函数

根据不同模型需要单独的设置优化器,这是训练过程决定的。

对生成器 𝐺𝐺 及其判别器 𝐷𝑌𝐷𝑌 ,目标损失函数定义为:

𝐿𝐺𝐴𝑁(𝐺,𝐷𝑌,𝑋,𝑌)=𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[𝑙𝑜𝑔𝐷𝑌(𝑦)]+𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[𝑙𝑜𝑔(1−𝐷𝑌(𝐺(𝑥)))]𝐿𝐺𝐴𝑁(𝐺,𝐷𝑌,𝑋,𝑌)=𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[𝑙𝑜𝑔𝐷𝑌(𝑦)]+𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[𝑙𝑜𝑔(1−𝐷𝑌(𝐺(𝑥)))]

其中 𝐺𝐺 试图生成看起来与 𝑌𝑌 中的图像相似的图像 𝐺(𝑥)𝐺(𝑥) ,而 𝐷𝑌𝐷𝑌 的目标是区分翻译样本 𝐺(𝑥)𝐺(𝑥) 和真实样本 𝑦𝑦 ,生成器的目标是最小化这个损失函数以此来对抗判别器。即 𝑚𝑖𝑛𝐺𝑚𝑎𝑥𝐷𝑌𝐿𝐺𝐴𝑁(𝐺,𝐷𝑌,𝑋,𝑌)𝑚𝑖𝑛𝐺𝑚𝑎𝑥𝐷𝑌𝐿𝐺𝐴𝑁(𝐺,𝐷𝑌,𝑋,𝑌) 。

单独的对抗损失不能保证所学函数可以将单个输入映射到期望的输出,为了进一步减少可能的映射函数的空间,学习到的映射函数应该是周期一致的,例如对于 𝑋𝑋 的每个图像 𝑥𝑥 ,图像转换周期应能够将 𝑥𝑥 带回原始图像,可以称之为正向循环一致性,即 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥 。对于 𝑌𝑌 ,类似的 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥 。可以理解采用了一个循环一致性损失来激励这种行为。

循环一致损失函数定义如下:

𝐿𝑐𝑦𝑐(𝐺,𝐹)=𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[‖𝐹(𝐺(𝑥))−𝑥‖1]+𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[‖𝐺(𝐹(𝑦))−𝑦‖1]𝐿𝑐𝑦𝑐(𝐺,𝐹)=𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[‖𝐹(𝐺(𝑥))−𝑥‖1]+𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[‖𝐺(𝐹(𝑦))−𝑦‖1]

循环一致损失能够保证重建图像 𝐹(𝐺(𝑥))𝐹(𝐺(𝑥)) 与输入图像 𝑥𝑥 紧密匹配。

# 构建生成器A的优化器
# 使用Adam优化算法,学习率为0.0002,beta1参数为0.5
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
# 构建生成器B的优化器,参数与生成器A相同
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

# 构建判别器A的优化器,参数与生成器相同
optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
# 构建判别器B的优化器,参数与判别器A相同
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

# 定义GAN网络的损失函数
# 使用均方误差损失函数(MSELoss),不使用sigmoid激活函数
loss_fn = nn.MSELoss(reduction='mean')  # reduction='mean'表示输出平均损失
# 定义L1损失函数
l1_loss = nn.L1Loss("mean")  # "mean"表示输出平均L1损失

# 定义GAN的损失函数
def gan_loss(predict, target):
    # 将目标转换为与预测相同的形状,值为1,表示判别器希望将真实样本判断为真
    target = ops.ones_like(predict) * target
    # 计算预测和目标之间的MSE损失
    loss = loss_fn(predict, target)
    return loss  # 返回损失值

前向计算

搭建模型前向计算损失的过程,过程如下代码。

为了减少模型振荡[1],这里遵循 Shrivastava 等人的策略[2],使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建 image_pool 函数,保留了一个图像缓冲区,用于存储生成器生成前的50个图像。

import mindspore as ms  # 导入MindSpore框架

# 定义生成器网络的前向计算函数
def generator(img_a, img_b):
    # 使用两个生成器网络生成假图像
    fake_a = net_rg_b(img_b)
    fake_b = net_rg_a(img_a)
    # 使用生成器网络进行循环一致性转换
    rec_a = net_rg_b(fake_b)
    rec_b = net_rg_a(fake_a)
    # 使用生成器网络生成同一风格的身份图像
    identity_a = net_rg_b(img_a)
    identity_b = net_rg_a(img_b)
    return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b

# 定义循环一致性损失和身份损失的权重
lambda_a = 10.0
lambda_b = 10.0
lambda_idt = 0.5

# 定义生成器的前向传播和损失计算函数
def generator_forward(img_a, img_b):
    # 创建一个布尔张量,值为True,用于GAN损失函数
    true = ms.Tensor(True, dtype=ms.bool_)
    # 调用生成器网络进行前向计算
    fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)
    # 计算生成器的对抗损失
    loss_g_a = gan_loss(net_d_b(fake_b), true)
    loss_g_b = gan_loss(net_d_a(fake_a), true)
    # 计算循环一致性损失
    loss_c_a = l1_loss(rec_a, img_a) * lambda_a
    loss_c_b = l1_loss(rec_b, img_b) * lambda_b
    # 计算身份损失
    loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idt
    loss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt
    # 总损失是所有损失的和
    loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b
    return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b

# 定义生成器的梯度计算函数
def generator_forward_grad(img_a, img_b):
    # 不需要生成的图像,只返回损失值
    _, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)
    return loss_g

# 定义判别器的前向传播和损失计算函数
def discriminator_forward(img_a, img_b, fake_a, fake_b):
    # 创建一个布尔张量,值为False,用于GAN损失函数
    false = ms.Tensor(False, dtype=ms.bool_)
    # 创建一个布尔张量,值为True,用于GAN损失函数
    true = ms.Tensor(True, dtype=ms.bool_)
    # 判别器对真实图像和假图像进行判断
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    # 计算判别器A的损失
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    # 计算判别器B的损失
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    # 总损失是两个判别器损失的平均
    loss_d = (loss_d_a + loss_d_b) * 0.5
    return loss_d

# 定义判别器A的前向传播和损失计算函数
def discriminator_forward_a(img_a, fake_a):
    # 与discriminator_forward相同,但只针对判别器A
    ...

# 定义判别器B的前向传播和损失计算函数
def discriminator_forward_b(img_b, fake_b):
    # 与discriminator_forward相同,但只针对判别器B
    ...

# 定义图像池函数,用于存储之前创建的图像
pool_size = 50  # 图像池的大小
def image_pool(images):
    # 如果传入的是Tensor,转换为numpy数组
    if isinstance(images, ms.Tensor):
        images = images.asnumpy()
    num_imgs = 0  # 当前图像池中的图像数量
    image1 = []  # 当前存储的图像列表
    return_images = []  # 返回的图像列表
    for image in images:
        # 如果图像数量小于池大小,直接添加
        if num_imgs < pool_size:
            num_imgs += 1
            image1.append(image)
            return_images.append(image)
        else:
            # 如果图像数量等于池大小,随机替换一个图像
            if ms.random.uniform(0, 1) > 0.5:
                random_id = ms.random.randint(0, pool_size - 1)
                tmp = image1[random_id].copy()
                image1[random_id] = image
                return_images.append(tmp)
            else:
                return_images.append(image)
    output = ms.Tensor(return_images, ms.float32)  # 将返回的图像列表转换为Tensor
    # 确保输出是四维的,即[batch_size, channels, height, width]
    if output.ndim != 4:
        raise ValueError("img should be 4d, but get shape {}".format(output.shape))
    return output

计算梯度和反向传播

其中梯度计算也是分开不同的模型来进行的,详情见如下代码:

from mindspore import value_and_grad  # 导入MindSpore的value_and_grad函数

# 使用value_and_grad函数实例化求梯度的方法
# 这将用于后续的梯度计算和反向传播
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())

grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())

# 定义训练生成器的函数
def train_step_g(img_a, img_b):
    # 设置判别器的梯度为False,因为在生成器的训练中不需要更新判别器
    net_d_a.set_grad(False)
    net_d_b.set_grad(False)

    # 调用generator_forward函数进行前向计算
    # 获取生成器生成的假图像、循环一致性图像和损失值
    fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)

    # 调用value_and_grad函数计算生成器的梯度
    _, grads_g_a = grad_g_a(img_a, img_b)
    _, grads_g_b = grad_g_b(img_a, img_b)

    # 使用优化器更新生成器A和B的参数
    optimizer_rg_a(grads_g_a)
    optimizer_rg_b(grads_g_b)

    # 返回生成器的输出和损失值
    return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib

# 定义训练判别器的函数
def train_step_d(img_a, img_b, fake_a, fake_b):
    # 设置判别器的梯度为True,因为在判别器的训练中需要更新判别器
    net_d_a.set_grad(True)
    net_d_b.set_grad(True)

    # 调用value_and_grad函数计算判别器A的梯度
    loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)
    # 调用value_and_grad函数计算判别器B的梯度
    loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)

    # 计算判别器的总损失,是两个判别器损失的平均
    loss_d = (loss_d_a + loss_d_b) * 0.5

    # 使用优化器更新判别器A和B的参数
    optimizer_d_a(grads_d_a)
    optimizer_d_b(grads_d_b)

    # 返回判别器的损失值
    return loss_d

模型训练

训练分为两个主要部分:训练判别器和训练生成器,在前文的判别器损失函数中,论文采用了最小二乘损失代替负对数似然目标。

  • 训练判别器:训练判别器的目的是最大程度地提高判别图像真伪的概率。按照论文的方法需要训练判别器来最小化 𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[(𝐷(𝑦)−1)2]𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[(𝐷(𝑦)−1)2] ;

  • 训练生成器:如 CycleGAN 论文所述,我们希望通过最小化 𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[(𝐷(𝐺(𝑥)−1)2]𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[(𝐷(𝐺(𝑥)−1)2] 来训练生成器,以产生更好的虚假图像。

下面定义了生成器和判别器的训练过程:

import os  # 导入os模块,用于操作文件和目录
import time  # 导入time模块,用于测量时间
import random  # 导入random模块,用于生成随机数
import numpy as np  # 导入numpy模块,用于数学运算
from PIL import Image  # 导入PIL模块,用于图像处理
from mindspore import Tensor, save_checkpoint  # 导入MindSpore模块
from mindspore import dtype  # 导入MindSpore的数据类型模块

# 设置训练的epoch数,可以根据需要调整
epochs = 1
# 设置每个epoch中保存模型的步骤数
save_step_num = 80
# 设置每个epoch中保存模型的频率
save_checkpoint_epochs = 1
# 设置保存模型的目录
save_ckpt_dir = './train_ckpt_outputs/'

print('Start training!')  # 打印开始训练的信息

for epoch in range(epochs):  # 遍历每个epoch
    g_loss = []  # 初始化生成器损失列表
    d_loss = []  # 初始化判别器损失列表
    start_time_e = time.time()  # 记录epoch开始的时间
    for step, data in enumerate(dataset.create_dict_iterator()):  # 遍历数据集中的每个数据点
        start_time_s = time.time()  # 记录当前步骤开始的时间
        img_a = data["image_A"]  # 获取数据集中的image_A
        img_b = data["image_B"]  # 获取数据集中的image_B
        res_g = train_step_g(img_a, img_b)  # 训练生成器并获取结果
        fake_a = res_g[0]  # 获取生成器生成的假图像A
        fake_b = res_g[1]  # 获取生成器生成的假图像B

        res_d = train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))  # 训练判别器并获取结果
        loss_d = float(res_d.asnumpy())  # 将损失转换为float类型
        step_time = time.time() - start_time_s  # 计算当前步骤的执行时间

        res = []
        for item in res_g[2:]:  # 获取生成器的损失
            res.append(float(item.asnumpy()))
        g_loss.append(res[0])  # 将生成器的损失添加到列表
        d_loss.append(loss_d)  # 将判别器的损失添加到列表

        if step % save_step_num == 0:  # 每隔save_step_num步打印一次训练信息
            print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "
                  f"step:[{int(step):>4d}/{int(datasize):>4d}], "
                  f"time:{step_time:>3f}s,\n"
                  f"loss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, "
                  f"loss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, "
                  f"loss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, "
                  f"loss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f}")

    epoch_cost = time.time() - start_time_e  # 计算epoch的执行时间
    per_step_time = epoch_cost / datasize  # 计算每步的平均时间
    mean_loss_d, mean_loss_g = sum(d_loss) / datasize, sum(g_loss) / datasize  # 计算平均损失

    print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "
          f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, "
          f"mean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f}")

    if epoch % save_checkpoint_epochs == 0:  # 每隔save_checkpoint_epochs个epoch保存一次模型
        os.makedirs(save_ckpt_dir, exist_ok=True)  # 创建保存模型的目录
        save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, f"g_a_{epoch}.ckpt"))  # 保存生成器A的模型
        save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, f"g_b_{epoch}.ckpt"))  # 保存生成器B的模型
        save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, f"d_a_{epoch}.ckpt"))  # 保存判别器A的模型
        save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, f"d_b_{epoch}.ckpt"))  # 保存判别器B的模型

print('End of training!')  # 打印训练结束的信息

模型推理

下面我们通过加载生成器网络模型参数文件来对原图进行风格迁移,结果中第一行为原图,第二行为对应生成的结果图。

import matplotlib.pyplot as plt
import numpy as np
from mindspore import Tensor

# 定义一个函数,用于加载和评估数据
def eval_data(dir_path, net, a):
    # 使用PIL库打开图像文件
    def read_img():
        for dir in os.listdir(dir_path):
            path = os.path.join(dir_path, dir)
            img = Image.open(path).convert('RGB')
            yield img, dir

    # 创建一个数据集对象,用于生成图像数据
    dataset = ds.GeneratorDataset(read_img, column_names=["image", "image_name"])
    # 定义图像预处理操作:调整大小、归一化和通道顺序转换
    trans = [vision.Resize((256, 256)), vision.Normalize(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3), vision.HWC2CHW()]
    dataset = dataset.map(operations=trans, input_columns=["image"])
    dataset = dataset.batch(1)
    
    # 创建一个matplotlib图形,用于显示图像
    fig = plt.figure(figsize=(11, 2.5), dpi=100)
    for i, data in enumerate(dataset.create_dict_iterator()):
        img = data["image"]  # 获取原始图像数据
        fake = net(img)  # 使用网络生成风格迁移后的图像

        # 对图像数据进行标准化和格式转换
        fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))
        img = (img[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))

        # 添加原始图像到图形的子图
        fig.add_subplot(2, 8, i+1+a)
        plt.axis("off")  # 关闭坐标轴
        plt.imshow(img.asnumpy())  # 显示图像

        # 添加风格迁移后的图像到图形的子图
        fig.add_subplot(2, 8, i+9+a)
        plt.axis("off")  # 关闭坐标轴
        plt.imshow(fake.asnumpy())  # 显示图像

# 调用eval_data函数,评估不同数据集的图像
eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0)
eval_data('./CycleGAN_apple2orange/predict/orange', net_rg_b, 4)

# 显示所有子图组成的图形
plt.show()

学习心得:

通过学习,我深刻理解了CycleGAN的核心原理,即通过对抗训练实现图像风格之间的转换,而无需成对的训练样本。这种无监督学习的方法极大地扩展了图像风格迁移的应用场景。

数据预处理是模型训练的基础。在CycleGAN中,图像的归一化、裁剪和增强等操作对于提高模型性能至关重要。我学会了如何使用MindSpore等工具进行有效的数据预处理。

CycleGAN由生成器和判别器组成,生成器负责生成风格迁移后的图像,判别器则用于区分真实图像和生成图像。我学习了如何设计这些网络结构,并通过实验验证了不同网络结构对模型性能的影响。

损失函数是训练过程中的指导信号。我了解到CycleGAN中循环一致性损失和对抗损失的重要性,它们共同作用于生成器和判别器,引导模型学习到更好的风格迁移特征。

在学习过程中,我掌握了如何使用MindSpore等深度学习框架进行模型训练,包括设置学习率、选择优化器和调整超参数。通过反复实验,我学会了如何调优模型以获得更好的性能。

模型评估是检验学习效果的关键。我学会了如何使用不同的指标(如均方误差、结构相似性指数等)来评估模型的生成效果。同时,我也学会了如何可视化生成的图像,直观地展示风格迁移的效果。

学习CycleGAN不仅是对技术的学习,更是对创新思维的培养。我相信通过不断实践和探索,能够将这些知识应用到更广泛的领域,解决更多的实际问题。

加油!!!

  • 9
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值