CycleGAN详解

前言: 想要实现照片风格转和油画风格互转吗?想要实现斑马野马互转吗?想要实现苹果橘子互转等这些任务吗?没错,CycleGAN网络就能够帮你满足这一目标!

details

CycleGAN详解

我们先来用一些简单的话语描述GAN。

所谓GAN,按照我的理解就是:
1.一个具有生成器和判别器的网络结构
2.生成器主要负责从随机样本空间高斯采样随机点,然后生成假图片,类似于造假货
3.而判别器是一个二分类的神经网络,主要负责判别喂给它的图像是来自真实世界的图像还是生成器生成的假图像,类似于博学的专家。
直到生成器生成出的图像,连博学的专家都分辨不出来了,那么就达到了所谓的纳什均衡了!

但是呢,GAN虽然很好,要是我如果想让普通照片和油画互转,阁下该怎么应对呢?
那么不得不提到我们今天的主角,CycleGAN

2.1 模型结构

11
如上图(a)所示,CycleGAN的结构是两个GAN组成的,也就是说,它有两副GAN的结构,即两个生成器和两个判别器,我们尝试做如下定义:

X 代表的是莫奈的油画 , Y 代表的是普通照片 D x 代表的是 Y → X 的判别器 , D y 代表的是 X → Y 的判别器 G 代表的是莫奈转普通照片的生成器,而 F 代表普通照片转莫奈油画 X代表的是莫奈的油画 , Y 代表的是普通照片 \\ D_x代表的是Y \to X的判别器,D_y代表的是X \to Y的判别器 \\ G代表的是莫奈转普通照片的生成器,而F代表普通照片转莫奈油画 X代表的是莫奈的油画,Y代表的是普通照片Dx代表的是YX的判别器,Dy代表的是XY的判别器G代表的是莫奈转普通照片的生成器,而F代表普通照片转莫奈油画

清楚了上面的定义,我们参照原始代码中生成器的结构:
在这里插入图片描述
在这里生成器是一个UNet形状的结构,输入一个256x256的图像,然后通过一系列块:

  • 其中CLI是由卷积、InstanceNorm和Leaky RELU组成的。
  • ReflectionPad(*)是一种图像增强方式,使得图像沿着边缘上下左右进行对称,增大图像分辨率的方式。
  • Residual block模块负责将数据进行恢复增强。

那么我们由如上的定义,可以写出如下的代码了!

# 分别定义两个判别器和两个生成器
# 生成器的定义
class GeneratorResNet(nn.Module):
    def __init__(self, input_shape, num_residual_blocks):
        super(GeneratorResNet, self).__init__()

        channels = input_shape[0]

        # Initial convolution block
        out_features = 64
        model = [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer
        model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

判别器特别简单,可以归纳为一个线性结构,直上直下,最后展平成一个(b,1)的维度。
在这里插入图片描述

所以我们再定义判别器:

# 判别器的定义
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        channels, height, width = input_shape

        # Calculate output shape of image discriminator (PatchGAN)
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, img):
        return self.model(img)

声明好结构之后,我们来定义生成器和判别器:

G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

2.2 一致性损失

难道这就是全部了吗?当然不是!因为你想想,如果这种方式直接训练,那么会出现这种问题:

莫奈油画转普通照片,只是让生成的图像很像普通照片的风格,但是一点儿莫奈油画里画的元素都没有!
也就是说,莫奈油画里画了一只狗(的油画),但是我生成的照片是一只其他的小动物(的照片),牛头不对马嘴了~

那该怎么办呢?所以作者提出了一致性损失!

简而言之,就是A -> B之后,B 再次转回到 A时,生成的图片要和A初始的图像对得上(也就是最小化损失,保持图片的一致性)

公式如下:
loss_cyc
上面的损失什么意思呢?就是:
x → G ( x ) : x 被随机采样,然后过 G ( ∗ ) 生成器,得到 G ( x ) G ( x ) → F ( G ( x ) ) : G ( x ) 送到 F ( ∗ ) 生成器 , 也就是上面说的再生成回来 最后得到结果和原始 x 的 L 1 损失 x \to G(x) : x被随机采样,然后过G(*)生成器,得到G(x) \\ G(x) \to F(G(x)) : G(x) 送到 F(*)生成器,也就是上面说的再生成回来 \\ 最后得到结果和原始x的L1 损失 xG(x):x被随机采样,然后过G()生成器,得到G(x)G(x)F(G(x)):G(x)送到F()生成器,也就是上面说的再生成回来最后得到结果和原始xL1损失
那么总损失可以概括为:
总损失
因此可以定义出文中的所有损失:

# GAN损失
criterion_GAN = torch.nn.MSELoss()

# A to B 循环损失
criterion_cycle = torch.nn.L1Loss()

# B to A 循环损失
criterion_identity = torch.nn.L1Loss()

最后是训练的代码:

    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            # A 、B都是真实数据
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))

            # Adversarial ground truths
            # 定义标签
            valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Identity loss
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)

            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------

            optimizer_D_A.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples)
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2

            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------

            optimizer_D_B.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_GAN.item(),
                    loss_cycle.item(),
                    loss_identity.item(),
                    time_left,
                )
            )

            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

The End

参考资料

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值