一、Pix2Pix 概述
- 简介
- Pix2Pix 是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的深度学习图像转换模型,由 Phillip Isola 等作者在 2017 年 CVPR 上提出。
- 作用:能够在给定条件的情况下,实现不同类型图像之间的转换。
- 应用
- 能够实现多种图像转换任务,如语义/标签到真实图片、灰度图到彩色图、航空图到地图、白天到黑夜、线稿图到实物图等。
- 组成
- 包括生成器和判别器两个关键部分。
- 生成器用于生成与目标域相似的图像,判别器用于判断生成的图像是否真实。
二、基础原理
- cGAN 与传统 GAN 生成器的区别
- cGAN 生成器以输入图片作为指导信息,由输入图像不断尝试生成用于迷惑判别器的“假”图像,本质是从像素到另一个像素的映射。
- 传统 GAN 生成器基于一个给定的随机噪声生成图像,输出图像通过其他约束条件控制生成。
- 公式对比:
- cGAN:
y = G(x,z)
,其中x
来自于训练数据,z
为随机噪声。 - 传统 GAN:直接由随机噪声
z
生成“假”图像,不借助观测图像x
的任何信息。
- cGAN:
- Pix2Pix 中判别器的任务
- 判别器的任务是判断从生成器输出的图像是真实的训练图像还是生成的“假”图像。
- 在生成器与判别器的不断博弈过程中,模型会达到一个平衡点,使得判别器刚好具有 50%的概率判断正确。
- cGAN 的目标
- 损失函数:
L_{cGAN}(G,D) = E_{(x,y)}[log(D(x,y))] + E_{(x,z)}[log(1 - D(x,G(x,z)))]
- 目标简化:
argmin_G max_D L_{cGAN}(G,D)
- 解释:D 想要尽最大努力去正确分类真实图像与“假”图像,G 则尽最大努力用生成的“假”图像
y
欺骗 D。
- 损失函数:
- GAN 的目标
L_{GAN}(G,D) = E_y[log(D(y))] + E_{(x,z)}[log(1 - D(x,z))]
- 对比:GAN 直接由随机噪声
z
生成“假”图像,不借助观测图像x
的信息。
- cGAN 与传统损失混合
- 假设 cGAN 与 L1 正则化混合使用。
- 正则化损失:
L_{L1}(G) = E_{(x,y,z)}[||y - G(x,z)||_1]
- 最终目标:
argmin_G max_D L_{cGAN}(G,D) + λL_{L1}(G)
三、准备环节
- 配置环境
- 本案例在 GPU、CPU 和 Ascend 平台的动静态模式都支持。
- 准备数据
- 下载指定的外墙(facades)数据集。
- 使用
download
库下载数据集。
- 使用
- 调用
Pix2PixDataset
和create_train_dataset
读取训练集。 - 数据展示:
- 调用
MindDataset
读取数据。 - 使用
create_dict_iterator
创建字典迭代器获取数据。 - 利用
matplotlib.pyplot
可视化部分训练数据。
- 调用
- 下载指定的外墙(facades)数据集。
四、创建网络
- 生成器 G 结构
- U-Net:
- 是德国 Freiburg 大学模式识别和图像处理组提出的一种全卷积结构。
- 分为压缩路径和扩张路径,压缩路径由卷积和降采样操作组成,扩张路径由卷积和上采样组成。
- 扩张的每个网络块的输入由上一层上采样的特征和压缩路径部分的特征拼接而成,形成 U 形结构。
- 区别:与常见的编解码结构相比,U-Net 加入了 skip-connection,用于保留不同分辨率下像素级的细节信息。
UNetSkipConnectionBlock
:- 定义了 U-Net 中的基本块,包含卷积、降采样、上采样、归一化、激活等操作。
UNetGenerator
:- 基于 U-Net 的生成器,通过堆叠多个
UNetSkipConnectionBlock
构建而成。
- 基于 U-Net 的生成器,通过堆叠多个
- U-Net:
- 判别器 D 结构
ConvNormRelu
:- 定义了卷积、归一化和激活的组合操作。
Discriminator
:- 基于 PatchGAN 的判别器,可看做卷积。
- 生成的矩阵中的每个点代表原图的一小块区域(patch),通过矩阵中的各个值来判断原图中对应每个 Patch 的真假。
- 初始化
- 实例化 Pix2Pix 生成器和判别器。
- 使用不同的初始化方法对生成器和判别器的参数进行初始化。
五、训练
- 目的:
- 训练判别器:最大程度地提高判别图像真伪的概率。
- 训练生成器:希望能产生更好的虚假图像,以欺骗判别器。
- 流程:
- 定义损失函数:
BCEWithLogitsLoss
:用于计算二分类交叉熵损失。L1Loss
:用于计算 L1 损失。
- 定义优化器:
Adam
优化器,设置学习率等参数。 - 计算梯度:使用
value_and_grad
计算判别器和生成器的梯度。 - 训练步骤:
- 通过
train_step
函数进行,包括计算损失、梯度更新等操作。 - 在每个 epoch 中,迭代数据集进行训练,并打印损失等信息。
- 通过
- 定义损失函数:
六、推理
- 加载训练好的生成器模型参数。
- 使用
load_checkpoint
和load_param_into_net
函数加载参数。
- 使用
- 对数据进行推理并展示效果。
- 读取数据集,使用生成器进行推理。
- 利用
matplotlib.pyplot
展示输入图像和推理生成的图像。
七、调用的库及功能
download
:用于下载数据集。mindspore.dataset
:MindDataset
:用于读取和处理数据集。create_dict_iterator
:创建字典迭代器以方便遍历数据集。shuffle
:对数据集进行随机打乱。num_parallel_workers
:设置并行处理数据的工作线程数。
mindspore.nn
:- 定义各种神经网络层,如卷积层、归一化层、激活函数等。
- 定义损失函数,如
BCEWithLogitsLoss
、L1Loss
。 - 定义优化器,如
Adam
。 - 进行参数初始化。
mindspore.ops
:进行各种操作,如拼接、卷积等。matplotlib.pyplot
:用于数据可视化,展示图像等。
八、操作流程
- 准备数据:
- 下载数据集。
- 读取和处理数据集。
- 展示部分训练数据。
- 构建网络:
- 定义生成器的结构,包括 U-Net 的基本块和生成器类。
- 定义判别器的结构,包括卷积、归一化和激活的组合,以及判别器类。
- 实例化生成器和判别器。
- 初始化网络参数。
- 训练模型:
- 定义损失函数和优化器。
- 计算梯度。
- 进行训练,更新参数,打印损失信息。
- 进行推理:
- 加载训练好的生成器参数。
- 对新数据进行推理。
- 展示推理结果。