GAN网络调参经验

版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/qq_38262728/article/details/98327082


自从 GAN 提出后,它变得越来越火热,吸引了众多的爱好者前来学习实践。

但是只要你自己去从无到有写出一个 GAN 模型并运行,除非你运气太好,大多数情况下你都会发现自己的GAN并不能很好地 work 。

下面首先对 GAN 进行简要的介绍,然后整理了我自己在 GAN 的设计网络结构、调整参数等方面的经验。

1 什么是GAN?

GAN 是一种生成模型,由知名的学者 Ian Goodfellow 首先提出,并给出了实验结果和理论推导 https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf

它以造假币为例对 GAN 的工作原理进行解释,生成器(Generator)就像造假币的人,判别器(Discriminator)就像警察,原始数据的分布类比于真钱,生成的数据分布类比于假钱。

造假币的人不断模仿真钱去造假币,造出来的钱混入真钱一同交给警察去判断。造假币的人的目的是让自己造出来的假币不断逼近于真钱,而警察既需要认出假钱、又不能冤枉真钱。

两者以此方式,不断地对抗提升自己造假和打假的能力,最终理想的结果是使得造假币的人能造出几乎无法辨识的假钱成功迷惑警察。

2 GAN存在的问题

  • 训练不稳定,损失值波动幅度大
  • 判别器收敛迅速,损失值快速降到零
  • 生成器无能为力,损失函数不断增大

3 训练的经验

3.1 不要纠结于损失函数的选择

刚开始你可能会认为损失函数对结果会产生较大的影响,但是实践证明,一般来说其对结果的影响一般并没有你想象的那样大。

因此,对于 GAN 理论入门不久,正在打开实践大门的人,我的建议是选择最简单的损失函数就可以开始实验了。

因为后续还有好多事情值得你去头疼,微调损失函数可以留到最后一步再考虑。

3.2 关于增加模型的容量

当GAN生成的图像不够准确、清晰时,可尝试增加卷积层中的卷积核的大小和数量,特别是初始的卷积层。

卷积核的增大可以增加卷积的视野域,平滑卷积层的学习过程,使得训练不过分快速地收敛。

增加卷积核数(特别是生成器),可以增加网络的参数数量和复杂度,增加网络的学习能力。

但同时也可能存在,增加生成器的模型 capacity 但是对于它快速被判别器打败的事实无济于事的情况,每个人都使用不同的模型和数据,会有不同的情况,需要具体问题具体分析。

3.3 尝试改变标签

如果使用的是真实数据标签为1,生成数据标签为0的分配方法,可将其交换为真实数据标签为0,生成数据标签为1。

这个小技巧会帮助网络在早期快速进行梯度计算,帮助稳定训练过程。

此外,还可使用软标签和带噪声的标签。

所谓软标签指不是使用0和1作为标签,而是使用和0或1接近的小数来标记,这样会减弱梯度的传播速度,稳定训练。

而使用带噪声的标签指对少数的标签进行随机的扰动,这也是一个帮助训练的小技巧。

3.4 尝试使用 batch normalization

我在实践的过程中使用 batch normalization ,发现对结果的提升具有明显的帮助,它在每一层都对数据进行归一化,有利于防止数据发散,进而保护训练的过程与结果的稳定性。

3.5 尝试分次训练

对于一般的 GAN 模型和多分类问题,最好分次训练,一次只训练一个类别,以降低网络训练的难度并提高准确性。

而对于条件 GAN 等,比如可以将类比标签一同作为输入,以类别为先验条件的 GAN ,可适度增大训练的难度。

3.6 最好不要提早结束

有时候我们会看到自己模型的损失函数在几个batch训练过后就停止波动了,但是这个时候先不要为了节省时间而提前停止训练,实践证明这个时候网络很可能仍然在不断地调整结构中。

有时候损失函数也可能突然出现很大的异常波动,这个时候也不要马上提前停止训练,多观察一会儿。

非常建议在训练的过程中,通过保存等方式不断记录当前时刻下的训练结果。通过对结果图像的观察分析来判断训练的过程,损失函数可能会一时蒙蔽双眼,结果应该不会。

因此除非损失马上收敛到接近于0,否则耐心地等待网络训练完再评估结果,调整网络结构和参数。

3.7 关于k的选择

原论文中的 k 指每优化一次生成器的损失函数,优化判别器的损失函数 k 次。

但是在实验中,经常出现判别器迅速打败生成器的情况(即判别器的损失函数快速下降,生成器快速上升)。

于是常规的思路,就是增加生成器的训练次数。没训练一次判别器,训练k次生成器。这样可以增加生成器的学习次数,使得训练在开始时稍稳定。

然而实践证明,如果判别器真的比生成器强太多,这种调节k只是让结果崩溃来的晚一些。或者说只是相当于节省了少训练几次判别器的时间,稍稍提升了结果。

我个人不建议出问题就改k的习惯,还是应该从网络结构本身找问题所在才是治本的关键。

3.8 关于学习率

调整学习率是解决生成器崩溃的一剂良方。

当出现崩溃时,尝试降低学习率,可能会带来意想不到的效果。

3.9 增加噪声

与标签噪声相似,还可在数据中引入一定量的噪声,大多数情况下都能 work 。
在这里插入图片描述

3.10 不要使用性能太好的判别器

WGAN论文提到过,若使用性能过好的判别器可能会使得判别器的损失函数在训练一开始就降到非常低,后续对抗无法继续进行,或者使得训练出来的生成器性能不够好。

3.10 可以尝试最新的multi-scale gradient方法

https://arxiv.org/abs/1903.06048

对于稳定训练帮助很大。

3.11 可以尝试使用TTUR

https://arxiv.org/abs/1706.08500

对于生成器和判别器使用不同的学习率,看似简单的 trick 对结果的提升却有奇效。

3.12 使用Spectral Normalization

https://arxiv.org/abs/1802.05957

对卷积核使用Spectral Normalization,极力安利。

4 正常的损失函数波动情况

目前来看,正常的损失函数应该是:

  • 训练初始,生成器和判别器的损失函数快速波动,但是大致都分别朝着增大或减小的方向。
  • 趋于稳定后,生成器和判别器的损失函数在小的范围内做上下波动,此时模型趋于稳定。

参考

[1] https://arxiv.org/pdf/1406.2661.pdf
[2] https://mp.weixin.qq.com/s?__biz=MzUzNTA1NTQ3NA==&mid=2247486336&idx=1&sn=57c9fe8324a1addd73016c2f9dad4db8&chksm=fa8a169dcdfd9f8b17a02ab37eba61fdb3a1d89f694eaf89e2159275a553efe81c848e9a597c&mpshare=1&scene=1&srcid=&sharer_sharetime=1564758935068&sharer_shareid=f48c6499a7bee75abed9252093ec8062&key=83b29471f317cf4cb4c43b8d6f0f7141528141839d921d9fa05354867868f61243968a92b031ba8d4867003242ab09f1ca621380db5b7bc77bfcab13dc9cc7a0960adac628f5a805694c9fef0468a345&ascene=1&uin=MjYxNDk4MjcwNg%3D%3D&devicetype=Windows+10&version=62060833&lang=zh_CN&pass_ticket=RHjQjboJdAJhysQNM17TfCzpyiuR4K3LIS%2FvyT9wAnt%2BBDxNq0hsDyAO0BNEjE6l
[3] https://towardsdatascience.com/10-lessons-i-learned-training-generative-adversarial-networks-gans-for-a-year-c9071159628

                                </div>
            <link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-e44c3c0e64.css" rel="stylesheet">
                </div>

更多《计算机视觉与图形学》知识,可关注下方公众号:
在这里插入图片描述

  • 20
    点赞
  • 90
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
生成对抗网络GAN)是一种深度学习模型,由生成器(Generator)和判别器(Discriminator)组成。生成器试图生成逼真的样本,而判别器则试图区分生成的样本和真实的样本。GAN在很多领域,如图像生成、语音合成等方面取得了巨大成功。 GAN的matlab代码解读如下: 1. 导入所需工具包,例如图像处理工具包、深度学习工具包等。 2. 定义生成器函数,输入为一个随机噪声向量,输出为生成的样本。生成器通常由卷积和反卷积层组成,其中反卷积层用于将随机噪声逐渐转化为生成样本。 3. 定义判别器函数,输入为一个样本,输出为该样本为真实样本的概率。判别器通常由卷积和全连接层组成,用于提取样本的特征并进行分类。 4. 定义生成器和判别器的优化器,通常使用随机梯度下降(SGD)算法。 5. 通过定义损失函数,将生成器的输出与判别器的输出进行对比,从而指导网络的训练。常用的损失函数有交叉熵损失函数、均方误差损失函数等。 6. 定义训练循环,每次循环中进行以下操作: a. 生成一个随机噪声向量,作为输入给生成器。 b. 通过生成器生成一个样本。 c. 将真实样本和生成样本输入给判别器,计算判别器的输出。 d. 计算生成器的损失函数,并更新生成器的权重。 e. 计算判别器的损失函数,并更新判别器的权重。 7. 循环进行多次实验,直到生成器能够生成逼真的样本。 通过以上步骤,GAN可以训练出逼真的样本,并且生成器和判别器会不断互相提升,达到一种平衡状态。GAN的优点在于能够生成新颖、多样的样本,但也存在一些挑战,如训练稳定性和模式崩溃等问题,需要进一步的优化调参
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值