关于生成对抗网络(GAN)

一、前言

    GAN可以说是最近10年来深度学习领域最为突出的创新,GAN是深度学习和博弈论相结合的产物,GAN的诞生,在生成领域引发了一些列的创新和应用。

二、博弈论原理

    GAN的思想是是一种二人零和博弈思想(Two-player Game)。零和博弈指的是参与博弈的各方,在严格竞争下,一方的收益必然意味着另一方的损失,博弈各方的收益和损失相加总和永远为“零”,双方不存在合作的可能

    GAN中有两个博弈者,一个是生成模型(Generator,简称 G),用来生成一张真实的图片,另一个是判别模型(Discriminator,简称 D),用来判别一张图片是生成出来的还是真实存在的。整个博弈的过程如下:

  • G 生成一张图片
  • D 学习区分图片是生成的图片还是来自真实分布的图片
  • G 根据判别模型的结果改进自己,生成新的图片
  • 上述过程交替进行,直至 D 无法判断一张图片是生成出来的还是真实的而结束

     这个博弈的过程结束后,此时G就会成为一个完美的生成器,因为它生成的图片使判别模型无法区分是真实的图片还是生成的图片,此时我们可以理解为模型达到了纳什均衡GD的能力都通过博弈达到了极致,双方不分伯仲)。

     收益分析:模型未收敛之前,D 能完美的识破生成的图片,因此 D 收益总是为1,G 收益总是为-1,所以总收益为0;收敛之后,D 无法区分生成的图和真实的图,所以 D 的收益是-1,G 成功骗过了 D,所以 G 收益为1,总收益仍然是0。所以整个训练过程是一个零和博弈过程。

三、GAN的基本结构

    如下图,G 通过学习如何将噪音源合成和真实数据分布类似的样本,而 D 则通过学习如何区分G 生成的样本和真实的样本。最简单的,G 和 D 一般通过MLP来实现,GAN的变种DCGAN通过卷积神经网络实现 G 和 D,本质上任何可导系统都可以用于实现 G 和 D

四、优化策略

    GAN采用了极小极大策略来求解,关于极小极大策略可参考参考文献中的相关介绍。极小极大的形式化描述如下。

     单看该等式,就能推断出这是一个两阶段的优化。在训练算法上,内存循环最大化 D 对真实样本的判别输出概率(最大化 D 的收益),外层循环最小化判别器的收益(判别器的收益可以理解为对伪造样本的判别输出概率)。极小极大策略是一种具备攻击性的优化方法,生成器总是不断的使判别器的最大收益最小化

五、训练步骤

    在每个Epoch训练过程,先对 D 训练 k 步,然后再训练一步 G 。一个有趣的问题是,G 如果训练的次数远超过 D,则会发生模式崩塌问题,原文称之为海奥维提卡现象此时的 G 生成的数据将不具备多样性,而失去了其存在的意义,具体可以参见参考文献。 

for epoch i in [1,2,......,n] do:
    for step j in [1,2,......,k] do:
        Sample m noise samples from P(z) and m real samples from P(x)
        Update parameters of D by SGD
    end
    Sample m noise samples from P(z)
    Update parameters of G by SGD
end

六、参考文献

生成对抗网络(Generative Adversarial Networks, GANs)是由Ian Goodfellow等人在2014年提出的深度学习模型架构[^4]。GAN由两个主要组成部分组成:生成器(Generator)和判别器(Discriminator)。它们通过一种零和博弈的方式相互作用。 **生成器**:尝试学习从随机噪声(通常是高斯分布)中生成与训练数据相似的新样本。它的目标是尽可能地欺骗判别器,使其误认为生成的数据是真实的。 **判别器**:负责区分真实数据和生成的数据。它试图准确地判断输入是来自训练数据还是生成器。 GAN的工作流程如下: 1. **训练过程**:生成器接收随机噪声作为输入并生成假样本,判别器则对这些样本进行分类,判断是真样本还是假样本。生成器根据判别器的反馈更新参数以提高生成能力,判别器也相应地调整其参数以提高识别能力。 2. **对抗迭代**:这两个模型交替优化,直到达到平衡状态,即生成器可以生成足够逼真的样本,使得判别器无法准确区分开来。 **示例代码**(简化版): ```python import torch.nn as nn # 假设我们有简单的生成器和判别器结构 class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() # ... def forward(self, noise): # 生成器的前向传播 pass class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() # ... def forward(self, input): # 判别器的前向传播 pass # 初始化并训练GAN generator = Generator() discriminator = Discriminator() for _ in range(num_epochs): fake_data = generator(noise) real_labels = torch.ones(batch_size) fake_labels = torch.zeros(batch_size) discriminator.zero_grad() d_loss_real = discriminator(real_data).mean() d_loss_fake = discriminator(fake_data.detach()).mean() d_loss = (d_loss_real + d_loss_fake).backward() discriminator_optimizer.step() generator.zero_grad() g_loss = discriminator(generator(noise)).mean().backward() generator_optimizer.step() ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值