昇思25天学习打卡营第11天 |昇思MindSpore CycleGAN 图像风格迁移学习

一、模型介绍

  1. CycleGAN
    • 循环对抗生成网络,来自论文 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
    • 作用:在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y
  2. 应用领域
    • 域迁移,通俗理解为图像风格迁移。
  3. 与 Pix2Pix 的区别
    • Pix2Pix 要求训练数据成对,CycleGAN 则不需要,更适用于现实中难以获取成对图像数据的情况。

二、模型结构

  1. 由两个镜像对称的 GAN 网络组成。
    • 以苹果和橘子为例,X 为苹果,Y 为橘子;G 为将苹果生成橘子风格的生成器,F 为将橘子生成苹果风格的生成器,D_XD_Y 为其相应判别器。
  2. 关键部分为循环一致损失(Cycle Consistency Loss),确保从一个域转换再转换回来能回到初始状态。

三、数据集

名称描述
来源ImageNet
内容只使用了其中的苹果橘子部分
预处理图像被统一缩放为 256×256 像素大小,进行随机裁剪、水平随机翻转和归一化,并转换为 MindRecord 格式
数量训练集:苹果 996 张、橘子 1020 张;测试集:苹果 266 张、橘子 248 张

四、可视化

使用 matplotlib 模块对训练数据进行可视化。

五、构建生成器

函数参数作用示例
ConvNormReLUinput_channel(输入通道数)、out_planes(输出通道数)、kernel_size(卷积核大小)、stride(步长)、alpha(LeakyReLU 的斜率)、norm_mode(归一化模式)、pad_mode(填充模式)、use_relu(是否使用 ReLU)、padding(填充大小)、transpose(是否为转置卷积)进行卷积、归一化和激活操作conv_norm_relu = ConvNormReLU(3, 64, 4, 2, 0.2, 'instance', 'CONSTANT', True)
ResidualBlockdim(维度)、norm_mode(归一化模式)、dropout(是否使用 dropout)、pad_mode(填充模式)构建残差块residual_block = ResidualBlock(64, 'instance', False)
ResNetGeneratorinput_channel(输入通道数)、output_channel(初始输出通道数)、n_layers(残差块数量)、alpha(LeakyReLU 的斜率)、norm_mode(归一化模式)、dropout(是否使用 dropout)、pad_mode(填充模式)构建生成器网络net_rg = ResNetGenerator(3, 64, 9, 0.2, 'instance', False)

六、构建判别器

函数参数作用示例
Discriminatorinput_channel(输入通道数)、output_channel(初始输出通道数)、n_layers(卷积层数)、alpha(LeakyReLU 的斜率)、norm_mode(归一化模式)构建判别器网络net_d = Discriminator(3, 64, 3, 0.2, 'instance')

七、优化器和损失函数

  1. 生成器和判别器分别使用单独的 Adam 优化器。
  2. 生成器的目标损失函数包括对抗损失和循环一致损失。
    • 对抗损失:L_{GAN}(G,D_Y,X,Y) = E_{y∼p_{data}(y)}[logD_Y(y)] + E_{x∼p_{data}(x)}[log(1 - D_Y(G(x)))]
    • 循环一致损失:L_{cyc}(G,F) = E_{x∼p_{data}(x)}[∥F(G(x)) - x∥_1] + E_{y∼p_{data}(y)}[∥G(F(y)) - y∥_1]

八、前向计算

函数参数作用示例
generatorimg_a(源域图像)、img_b(目标域图像)进行图像生成和转换fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)
generator_forwardimg_a(源域图像)、img_b(目标域图像)计算生成器的损失fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b = generator_forward(img_a, img_b)
discriminator_forwardimg_a(源域真实图像)、img_b(目标域真实图像)、fake_a(源域生成的假图像)、fake_b(目标域生成的假图像)计算判别器的损失loss_d = discriminator_forward(img_a, img_b, fake_a, fake_b)

九、计算梯度和反向传播

函数参数作用示例
value_and_grad待求梯度的函数、梯度相对于哪些输入、待优化的参数计算函数的梯度grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
train_step_gimg_a(源域图像)、img_b(目标域图像)计算生成器的梯度并反向传播更新参数fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = train_step_g(img_a, img_b)
train_step_dimg_a(源域图像)、img_b(目标域图像)、fake_a(源域生成的假图像)、fake_b(目标域生成的假图像)计算判别器的梯度并反向传播更新参数loss_d = train_step_d(img_a, img_b, fake_a, fake_b)

十、模型训练

  1. 分为训练判别器和生成器两部分。
    • 判别器训练目的是提高判别图像真伪的概率。
    • 生成器训练目的是产生更好的虚假图像。
  2. 训练过程中打印损失等信息,并定期保存模型参数。

十一、模型推理

函数参数作用示例
load_ckptnet(网络)、ckpt_dir(模型参数文件路径)加载模型参数load_ckpt(net_rg_a, g_a_ckpt)
eval_datadir_path(图像目录路径)、net(网络)、a(偏移量)对图像进行推理并展示结果eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0)
  • 10
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值