【万物皆可 GAN】CycleGAN 原理详解

概述

CycleGAN (Cycle Generative Adversarial Network) 即循环对抗生成网络. CycleGAN 可以帮助我们实现图像的互相转换. CycleGAN 不需要数据配对就能实现图像的转换.

在这里插入图片描述
从上图我们可以看到, 通过使用 CycleGAN 我们实现了马到斑马的转换.

CycleGAN 可以做什么

答: 万物皆可 GAN

图片转换

在这里插入图片描述
在这里插入图片描述

图片修复

在这里插入图片描述
在这里插入图片描述

换脸

在这里插入图片描述

在这里插入图片描述

CycleGAN 网络结构

CycleGAN 由左右两个 GAN 网络组成. G(AB) 负责把 A 类物体 (斑马) 转换成 B 类物体 (正常的马). G(BA) 负责把 B 类物体 (正常的马) 还原成 A 类物体 (斑马).

在这里插入图片描述
如果我们只有 G(AB) 一个网络, 生成器 (Generator) 就会偷懒, 用随意任何一匹马蒙混过关, 如图底部. 所以我们需要两个 GAN 网络, 通过循环约束生成器 (Generator).

在这里插入图片描述
如图, 完整的 CycleGAN 由上下两部分组成, 上下两部分的唯一区别在于输入. 一个输入是 A 类, 生成 B 类; 另一个输入是 B 类, 生成 A 类.

CycleGAN 损失函数

CycleGAN 的损失函数总共有 2 组, 每组 4 个, 总计 8 个. 如图:
在这里插入图片描述
其中:

  • D_A & D_B: 是判断器的损失
  • G_A & G_B: 是生成器的损失
  • cycle_A & cycle_B: 是原始图像和还原图像的损失, 即 A => B => A, 初始和和还原 A 的损失
  • idt_A & idt_B: 是映射损失, 即用真实的 B 当做输入, 查看生成器是否会原封不动的输出 (B => B?)

在这里插入图片描述

### CycleGAN 网络架构解析 #### 生成器架构 CycleGAN 的生成器采用了一种称为残差块(Residual Block)的设计来保持输入图像的结构信息[^1]。具体来说,生成器由以下几个部分组成: - **编码器**:负责提取输入图像中的特征表示。通常通过一系列下采样操作实现,这些操作可以是卷积层加上步幅或最大池化层。 - **变换模块**:此阶段利用多个残差块处理来自编码器的信息,在不改变空间维度的情况下修改特征向量的内容属性。残差连接有助于缓解梯度消失问题并促进更深层次模型的学习能力。 - **解码器**:将中间表征重建为目标域内的输出图片形式。它执行的是上采样的过程,逐步恢复原始分辨率的同时引入自适应实例归一化 (AdaIN)[^3] 来调整风格特性。 ```python import torch.nn as nn class ResnetBlock(nn.Module): def __init__(self, dim, padding_type='reflect', norm_layer=nn.BatchNorm2d, use_dropout=False, use_bias=False): super(ResnetBlock, self).__init__() conv_block = [] p = 0 if padding_type == 'reflect': conv_block += [nn.ReflectionPad2d(1)] elif padding_type == 'replicate': conv_block += [nn.ReplicationPad2d(1)] elif padding_type == 'zero': p = 1 conv_block += [ nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] if use_dropout: conv_block += [nn.Dropout(0.5)] p = 0 if padding_type != 'zero': conv_block += [nn.ReflectionPad2d(1)] conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] self.conv_block = nn.Sequential(*conv_block) def forward(self, x): out = x + self.conv_block(x) return out ``` #### 判别器架构 判别器遵循 PatchGAN 设计理念,旨在判断局部区域而非整个图像的真实性。其核心在于多尺度感知机制,即能够捕捉不同大小的感受野上的模式差异,从而提高伪造样本检测精度。 #### 损失函数构成 除了传统的 GAN 对抗损失外,CycleGAN 还加入了循环一致性的约束条件以确保跨领域映射的一致性和稳定性。这意味着当数据点经两次连续转换后应回到初始状态附近,以此强化双向翻译的质量控制[^2]。 ```python criterionGAN = torch.nn.MSELoss() lambda_A = lambda_B = 10.0 def compute_loss(real_A, fake_B, rec_A, real_B, fake_A, rec_B): loss_G_A = criterionGAN(netD_A(fake_B).detach(), target_real) * lambda_A loss_D_A = (criterionGAN(netD_A(real_B), target_real) + criterionGAN(netD_A(fake_B.detach()), target_fake)) / 2 loss_G_B = criterionGAN(netD_B(fake_A).detach(), target_real) * lambda_B loss_D_B = (criterionGAN(netD_B(real_A), target_real) + criterionGAN(netD_B(fake_A.detach()), target_fake)) / 2 loss_cycle_A = criterionCycle(rec_A, real_A) * lambda_A loss_cycle_B = criterionCycle(rec_B, real_B) * lambda_B total_loss = (loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B) return { "total": total_loss, "gan_a": loss_G_A.item(), "gan_b": loss_G_B.item(), "cycle_a": loss_cycle_A.item(), "cycle_b": loss_cycle_B.item(), "discriminator_a": loss_D_A.item(), "discriminator_b": loss_D_B.item() } ```
评论 20
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值