一、模型介绍
- CycleGAN
- 循环对抗生成网络,来自论文
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
。 - 作用:在没有配对示例的情况下学习将图像从源域
X
转换到目标域 Y
。
- 应用领域
- 与 Pix2Pix 的区别
- Pix2Pix 要求训练数据成对,CycleGAN 则不需要,更适用于现实中难以获取成对图像数据的情况。
二、模型结构
- 由两个镜像对称的 GAN 网络组成。
- 以苹果和橘子为例,
X
为苹果,Y
为橘子;G
为将苹果生成橘子风格的生成器,F
为将橘子生成苹果风格的生成器,D_X
和 D_Y
为其相应判别器。
- 关键部分为循环一致损失(Cycle Consistency Loss),确保从一个域转换再转换回来能回到初始状态。
三、数据集
名称 | 描述 |
---|
来源 | ImageNet |
内容 | 只使用了其中的苹果橘子部分 |
预处理 | 图像被统一缩放为 256×256 像素大小,进行随机裁剪、水平随机翻转和归一化,并转换为 MindRecord 格式 |
数量 | 训练集:苹果 996 张、橘子 1020 张;测试集:苹果 266 张、橘子 248 张 |
四、可视化
使用 matplotlib
模块对训练数据进行可视化。
五、构建生成器
函数 | 参数 | 作用 | 示例 |
---|
ConvNormReLU | input_channel (输入通道数)、out_planes (输出通道数)、kernel_size (卷积核大小)、stride (步长)、alpha (LeakyReLU 的斜率)、norm_mode (归一化模式)、pad_mode (填充模式)、use_relu (是否使用 ReLU)、padding (填充大小)、transpose (是否为转置卷积) | 进行卷积、归一化和激活操作 | conv_norm_relu = ConvNormReLU(3, 64, 4, 2, 0.2, 'instance', 'CONSTANT', True) |
ResidualBlock | dim (维度)、norm_mode (归一化模式)、dropout (是否使用 dropout)、pad_mode (填充模式) | 构建残差块 | residual_block = ResidualBlock(64, 'instance', False) |
ResNetGenerator | input_channel (输入通道数)、output_channel (初始输出通道数)、n_layers (残差块数量)、alpha (LeakyReLU 的斜率)、norm_mode (归一化模式)、dropout (是否使用 dropout)、pad_mode (填充模式) | 构建生成器网络 | net_rg = ResNetGenerator(3, 64, 9, 0.2, 'instance', False) |
六、构建判别器
函数 | 参数 | 作用 | 示例 |
---|
Discriminator | input_channel (输入通道数)、output_channel (初始输出通道数)、n_layers (卷积层数)、alpha (LeakyReLU 的斜率)、norm_mode (归一化模式) | 构建判别器网络 | net_d = Discriminator(3, 64, 3, 0.2, 'instance') |
七、优化器和损失函数
- 生成器和判别器分别使用单独的
Adam
优化器。 - 生成器的目标损失函数包括对抗损失和循环一致损失。
- 对抗损失:
L_{GAN}(G,D_Y,X,Y) = E_{y∼p_{data}(y)}[logD_Y(y)] + E_{x∼p_{data}(x)}[log(1 - D_Y(G(x)))]
- 循环一致损失:
L_{cyc}(G,F) = E_{x∼p_{data}(x)}[∥F(G(x)) - x∥_1] + E_{y∼p_{data}(y)}[∥G(F(y)) - y∥_1]
八、前向计算
函数 | 参数 | 作用 | 示例 |
---|
generator | img_a (源域图像)、img_b (目标域图像) | 进行图像生成和转换 | fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b) |
generator_forward | img_a (源域图像)、img_b (目标域图像) | 计算生成器的损失 | fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b = generator_forward(img_a, img_b) |
discriminator_forward | img_a (源域真实图像)、img_b (目标域真实图像)、fake_a (源域生成的假图像)、fake_b (目标域生成的假图像) | 计算判别器的损失 | loss_d = discriminator_forward(img_a, img_b, fake_a, fake_b) |
九、计算梯度和反向传播
函数 | 参数 | 作用 | 示例 |
---|
value_and_grad | 待求梯度的函数、梯度相对于哪些输入、待优化的参数 | 计算函数的梯度 | grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params()) |
train_step_g | img_a (源域图像)、img_b (目标域图像) | 计算生成器的梯度并反向传播更新参数 | fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = train_step_g(img_a, img_b) |
train_step_d | img_a (源域图像)、img_b (目标域图像)、fake_a (源域生成的假图像)、fake_b (目标域生成的假图像) | 计算判别器的梯度并反向传播更新参数 | loss_d = train_step_d(img_a, img_b, fake_a, fake_b) |
十、模型训练
- 分为训练判别器和生成器两部分。
- 判别器训练目的是提高判别图像真伪的概率。
- 生成器训练目的是产生更好的虚假图像。
- 训练过程中打印损失等信息,并定期保存模型参数。
十一、模型推理
函数 | 参数 | 作用 | 示例 |
---|
load_ckpt | net (网络)、ckpt_dir (模型参数文件路径) | 加载模型参数 | load_ckpt(net_rg_a, g_a_ckpt) |
eval_data | dir_path (图像目录路径)、net (网络)、a (偏移量) | 对图像进行推理并展示结果 | eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0) |