昇思25天学习打卡营第42天|生成式-CycleGAN图像风格迁移互换

昇思25天学习打卡营第42天|生成式-CycleGAN图像风格迁移互换

CycleGAN简介

CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络,实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法,有效解决了数据集难以获取的问题,扩展性更好,应用范围更广。

网络结构

CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成,即包括两个生成器和判别器,进行双向生成。

  • 生成器G:负责将源域图像(例如夏天的风景)转换为目标域图像(例如冬天的风景)。
  • 生成器F:负责将目标域图像转换回源域图像,确保风格迁移的一致性。
  • 判别器D_X:负责区分源域图像和生成的目标域图像,以促使生成器生成更加逼真的图像。
  • 判别器D_Y:负责区分目标域图像生成的源域图像

训练过程中,生成器不仅要欺骗判别器使其认为生成的图像是现实的,还需要通过循环一致性损失来确保图像在转换后能够恢复到原始样貌。这种循环一致性损失确保了源图像经过目标域风格转换后再转换回源域时,可以接近原始图像,从而保留了图像的基本内容和结构。

损失函数

对抗损失

  • 生成器 G 的对抗损失:鼓励 G 生成逼真的目标域图像 G(X),使得 D_Y 无法区分 G(X)和真实的 Y。
  • 生成器 F 的对抗损失:鼓励 F 生成逼真的源域图像 F(Y),使得 D_X 无法区分 F(Y) 和真实的 X。

循环一致性损失

  • 源域图像的循环一致性损失:确保将源域图像 X 转换为目标域图像 G(X)后,再转换回源域图像 F(G(X)) 时,能够重建出原始图像 X。
  • 目标域图像的循环一致性损失:确保将目标域图像 Y 转换为源域图像 F(Y) 后,再转换回目标域图像 G(F(Y)) 时,能够重建出原始图像 Y。

CycleGAN与cGAN比较

输入的区别

  • cGAN:生成器和判别器都接收额外的条件信息 y。生成器的输入是 (z,y),判别器的输入是 (x,y) 。
  • CycleGAN:输入是无需配对的训练数据,使用来自两个不同域的图像集(例如,域A的图像和域B的图像)。

输出的区别

  • cGAN:生成器输出是符合条件 y 的样本,判别器输出是在条件 y下数据是否真实的概率。
  • CycleGAN:生成器生成风格转换后的图像,例如将域A的图像转换为域B的风格;判别器分别对两个域的图像进行区分,判定图像是否来自真实数据。

优点

  1. 不需要配对数据,直接使用来自两个域的未配对图像,这在实际应用中更为灵活。
  2. 通过循环一致性约束,能够更好地保留图像内容的一致性,即使在风格转换后也能保持图像的结构特征。

缺点

  1. 模型训练更加复杂,计算量大,因为需要同时训练两个生成器和两个判别器。

基于“图像池”技术的前向计算

“图像池”技术:使用生成器生成的历史图像,而不是最新生成的图像,来更新判别器,以此来防止判别器过快地适应生成器的更新,增强生成器和判别器之间的对抗性。

import mindspore as ms

# 前向计算

def generator(img_a, img_b):
    fake_a = net_rg_b(img_b)
    fake_b = net_rg_a(img_a)
    rec_a = net_rg_b(fake_b)
    rec_b = net_rg_a(fake_a)
    identity_a = net_rg_b(img_a)
    identity_b = net_rg_a(img_b)
    return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b

lambda_a = 10.0
lambda_b = 10.0
lambda_idt = 0.5

def generator_forward(img_a, img_b):
    true = Tensor(True, dtype.bool_)
    fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)
    loss_g_a = gan_loss(net_d_b(fake_b), true)
    loss_g_b = gan_loss(net_d_a(fake_a), true)
    loss_c_a = l1_loss(rec_a, img_a) * lambda_a
    loss_c_b = l1_loss(rec_b, img_b) * lambda_b
    loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idt
    loss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt
    loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b
    return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b

def generator_forward_grad(img_a, img_b):
    _, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)
    return loss_g

def discriminator_forward(img_a, img_b, fake_a, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    loss_d = (loss_d_a + loss_d_b) * 0.5
    return loss_d

def discriminator_forward_a(img_a, fake_a):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    return loss_d_a

def discriminator_forward_b(img_b, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    return loss_d_b

# 保留了一个图像缓冲区,用来存储之前创建的50个图像
pool_size = 50
def image_pool(images):
    num_imgs = 0
    image1 = []
    if isinstance(images, Tensor):
        images = images.asnumpy()
    return_images = []
    for image in images:
        if num_imgs < pool_size:
            num_imgs = num_imgs + 1
            image1.append(image)
            return_images.append(image)
        else:
            if random.uniform(0, 1) > 0.5:
                random_id = random.randint(0, pool_size - 1)

                tmp = image1[random_id].copy()
                image1[random_id] = image
                return_images.append(tmp)

            else:
                return_images.append(image)
    output = Tensor(return_images, ms.float32)
    if output.ndim != 4:
        raise ValueError("img should be 4d, but get shape {}".format(output.shape))
    return output

image_pool 函数:存储生成器生成的历史图像,并根据一定的概率返回这些历史图像或最新生成的图像。

discriminator_forward 函数:对判别器训练并计算损失

generator_forward 函数:对生成器训练并计算损失

  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值