【神经网络】GAN原理总结,CatGAN

定义及原理:    

       生成器 (G)generator:接收一个随机的噪声z(随机数),通过这个噪声生成图像。G的目标就是尽量生成真实的图片去欺骗判别网络D。

       判别器(D) discriminator:对接收的图片进行真假判别。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。D的目标就是尽量辨别出G生成的假图像和真实的图像。

       GAN的主要灵感来源于博弈论中零和博弈的思想,应用到深度学习神经网络上来说,就是通过G和D不断博弈,进而使G学习到数据的分布,如果用到图片生成上,则训练完成后,G可以从一段随机数中生成逼真的图像。

      训练过程中,G和D构成了一个动态的“博弈过程”,最终的平衡点即纳什均衡点:生成器生成的图像接近于真实图像分布,而判别器识别不出真假图像,对于给定图像的预测为真的概率基本接近 0.5(相当于随机猜测类别)

过程

  1. 第一代的Generator,然后他产生一些图片
  2. 训练产生第一代discriminator,能够区分人工产生的和真实的图片
  3. 训练第二代Generator,使其产生的图片骗过第一代discriminator
  4. 以此类推。。。

优点

  1. 只用到了反向传播
  2. 相比其他所有模型, GAN可以产生更加清晰,真实的样本
  3. GAN应用到一些场景上,比如图片风格迁移,超分辨率,图像补全,去噪,避免了损失函数设计的困难,不管三七二十一,只要有一个的基准,直接上判别器,剩下的就交给对抗训练了

缺点

  1. 训练GAN需要达到纳什均衡,有时候可以用梯度下降法做到,有时候做不到.我们还没有找到很好的达到纳什均衡的方法,所以训练GAN相比VAE或者PixelRNN是不稳定的,但我认为在实践中它还是比训练玻尔兹曼机稳定的多
  2. GAN不适合处理离散形式的数据,比如文本
  3. GAN存在训练不稳定、梯度消失、模式崩溃的问题(目前已解决)

应用

  1. 图片生成
  2. 替换判别器为一个分类器,做多分类任务,而生成器仍然做生成任务,辅助分类器训练
  3. 和强化学习结合,目前一个比较好的例子就是seq-GAN

CatGAN

无监督的分类会被转化为一个聚类问题,通常是以某种距离作为度量准则,从而将数据划分为多个类别,而本文则是采用数据的熵来作为衡量标准构建来CatGAN (ICLR-2016) 。具体来说,对于真实的数据,模型希望判别器不仅能具有较大的确信度将其划分为真实样本,同时还有较大的确信度将数据划分到某一个现有的类别中去;而对于生成数据却不是十分确定要将其划分到哪一个现有的类别,也就是这个不确信度比较大,从而生成器的目标即为产生出那些“将其划分到某一类别中去”的确信度较高的样本,尝试骗过判别器。接下来,为了衡量这个确信程度,作者用熵来表示,熵值越大,即为越不确定;而熵值越小,则表示越确定。然后,将该确信度目标与原始GAN的真伪鉴别的优化目标结合,即得到了CatGAN的最终优化目标。

对于半监督的情况,对有标签数据计算交叉熵损失,而对无标签数据计算上面的基于熵的损失,然后在原来的目标函数的基础上进行叠加即得,当用该半监督方法进行目标识别与分类时,其效果虽然相对较优,但相对当下state-of-the-art的方法并没有比较明显的提升。但其基于熵损失的无监督训练方法却表现较好,其实验效果如下图所示,可以看到,对于如下的典型环形数据,CatGAN可以较好地找到两者的分类面,实现无监督聚类的功能。

GAN of Salimans et al. (2016)

参考:Improved Techniques for Training GANs

GAN网络使用梯度下降的方法只会找到低的损失,不能找到真正的纳什均衡。本论文中,作者通过引入了一些方法,提高网络的收敛。

原始的GAN网络的目标函数需要最大化判别网络的输出。作者提出了新的目标函数,motivation就是让生成网络产生的图片,经过判别网络后的中间层的feature 和真实图片经过判别网络的feature尽可能相同。

相比原先的方式,生成网络G产生的数据更符合数据的真实分布。作者虽然不保证能够收敛到纳什均衡点,但是在传统GAN不能稳定收敛的情况下,新的目标函数仍然有效。

判别网络从输入到输出逐层卷积,pooling,图片信息逐渐损失,因此中间层能够比输出层得到更好的原始图片的分布信息,拿中间层的feature作为目标函数比输出层的结果,能够生成图片信息更多,生成的图片会效果会更好。

  • Semi-supervised learning

对于GAN网络,可以把生成网络的输出作为第K+1类,相应的判别网络变为K+1类的分类问题。用Pmodel(y=K+1|x)Pmodel(y=K+1|x)表示生成网络的图片为假

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
GAN(Generative Adversarial Networks)是一种生成式神经网络,由生成器和判别器两部分组成。生成器的作用是生成与真实数据相似的假数据,而判别器的作用是判断输入的数据是真实数据还是生成器生成的假数据。两个网络相互对抗,通过不断的迭代训练,生成器可以生成越来越逼真的假数据,而判别器也可以越来越准确地判断真假数据。GAN的组成原理如下: 1.生成器:生成器是一个前馈神经网络,它接收一个随机噪声向量作为输入,并输出一个与真实数据相似的假数据。生成器的目标是尽可能地欺骗判别器,使其无法区分真实数据和假数据。 2.判别器:判别器是一个二分类器,它接收真实数据或生成器生成的假数据作为输入,并输出一个二元值,表示输入数据是真实数据还是假数据。判别器的目标是尽可能地准确地判断输入数据的真假。 3.对抗训练:生成器和判别器相互对抗,通过不断的迭代训练,生成器可以生成越来越逼真的假数据,而判别器也可以越来越准确地判断真假数据。 4.损失函数:GAN的损失函数由两部分组成,一部分是生成器的损失函数,另一部分是判别器的损失函数。生成器的损失函数是判别器判断生成器生成的假数据为真实数据的概率的负对数,而判别器的损失函数是真实数据和生成器生成的假数据的判别概率之和的负对数。 ```python # 以下是一个简单的GAN实现 # 生成器 generator = Sequential() generator.add(Dense(256, input_dim=100, activation='relu')) generator.add(Dense(512, activation='relu')) generator.add(Dense(1024, activation='relu')) generator.add(Dense(784, activation='tanh')) # 判别器 discriminator = Sequential() discriminator.add(Dense(1024, input_dim=784, activation='relu')) discriminator.add(Dense(512, activation='relu')) discriminator.add(Dense(256, activation='relu')) discriminator.add(Dense(1, activation='sigmoid')) # 对抗训练 gan_input = Input(shape=(100,)) fake_image = generator(gan_input) gan_output = discriminator(fake_image) gan = Model(inputs=gan_input, outputs=gan_output) discriminator.compile(loss='binary_crossentropy', optimizer='adam') gan.compile(loss='binary_crossentropy', optimizer='adam') # 相关问题: --相关问题--:
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值