ImprovedGAN论文略读

论文:https://arxiv.org/abs/1606.03498

源码:https://github.com/openai/improved_gan

参考:博客1博客2博客3

1  介绍

        GAN是基于博弈论的生成模型方法。GAN训练一个生成网络来生成尽可能真实的图像,一个判别网络尽可能区分真是图像和生成图像。训练GAN要求找到在连续高维参数下非凸博弈的纳什均衡。但是通常GAN用梯度下降方法去寻找损失函数的最小值,而不是纳什均衡,这可能会导致无法收敛。论文中介绍了几个方法去尽量实现GAN博弈的收敛。这些方法的灵感来源于非凸问题的启发式理解。可以帮助提升半监督学习性能和提升采样生成。

2  相关tricks

        首先需要对G做一个处理,使用Feature matching代替原来的关于G的loss函数。

2.1  Feature matching(特征匹配)

        Feature matching 中提出,在D(判别器)模型中,添加中间层,称之为f(x)。也就是,不是像以前一样考虑最后的结果(概率),而是考虑在中间层中的数据之间的差异性(这里使用的是二范数的平方)。因为原始的GAN网络的目标函数需要最大化判别网络的输出。作者提出了新的目标函数,目的是让生成网络产生的图片经过判别网络后的中间层的feature 和真实图片经过判别网络的feature尽可能相同。因此生成网络的目标函数定义如下:

判别网络按照原来的方式训练。相比原先的方式,生成网络G产生的数据更符合数据的真实分布。尽管不能保证到达均衡点,不过收敛的稳定性应该是有所提高。

2.2  minibatch discrimination(小批量判别)

        判别网络如果每次只看单张图片,如果判断为真的话,那么生成网络就会认为这里一个优化的目标,导致生成网络会快速收敛到当前点。作者使用了minibatch的方法,每次判别网络输入一批数据进行判断。

        假设f(x)\in R^{A}表示判别网络中间层的输出向量,作者将f(x)乘以矩阵T\in R^{A\times B\times C},得到一个矩阵M_{i}\in R^{B\times C}。计算矩阵M_{i}每行的L-1距离,得到c_b(x_i,x_j)=exp(-\left \| M_{i,b}-M_{j,b} \right \|_{L1}) \in R。定义输入x_i的输出o(x_i)如下:

o(x_i)_b=\sum_{j=1}^{n} c_b(x_i,x_j)\in R

o(x_i)=[o(x_i)_1,o(x_i)_2,o(x_i)_3,......o(x_i)_n]

 o(X) \in R^{n \times B}

o(x_i)作为输入,进入判别网络下一层的输入。

 2.3  Historical averaging(历史平均)

         在生成网络和判别网络的损失函数中添加一个项:

\left \| \theta -\frac{1}{t} \sum_{i=1}^{t} \theta [i] \right \|^{2}

公式中 \theta [i]表示在i时刻的参数。这个项在网络训练过程中,也会更新。加入这个项后,梯度就不容易进入稳定的轨道,能够继续向均衡点更新。

2.4  One-side label smooth(类别标签平滑

        将正例label乘以\alpha,, 负例label乘以\beta,最优的判别函数分类器变为:

 文中将正例label乘以\alpha,, 负例label乘以0。

2.5  Virtual batch normalization(虚拟的batch normalization)

        Normalization(归一化)使用能够提高网络的收敛,但是BN(批归一化)带来了一个问题,就是layer的输出和本次batch内的其他输入相关。为了避免这个问题,作者提出了一种新的bn方法,叫做virtual batch normalization。首先从训练集中拿出一个batch在训练开始前固定起来,算出这个特定batch的均值和方差,进行更新训练中的其他batch。VBN的缺点也显而易见,就是需要更新两份参数,比较耗时。

3  Semi-supervised learning(半监督学习)

        标准的分类网络将数据xx输出为可能的K个classes,然后对K维的向量(l_1,......,l_k)使用softmax:

p_{model}(y=j|x)=\frac{exp(l_j)}{\sum_{k=1}^{k}exp(l_k)}

 标准的分类是有监督的学习,模型通过最小化交叉熵损失,获得最优的网络参数。对于GAN网络,可以把生成网络的输出作为第k+1类,相应的判别网络变为k+1类的分类问题。用P_{model}(y=k+1|x)来表示生成网络的图片为假,用来代替GAN的1-D(x)。对分类网络,只需要知道某一张图片属于哪一类,不用明确知道这个类是什么,通过P_{model}(y \in 1,2,......,k|x)可以训练。因此损失函数变为:

 如果把D(x)=1-p_{model}(y=k+1|x),上述无监督的表达式就是GAN的形式(见2014的GAN论文https://blog.csdn.net/weixin_44855366/article/details/119734833):

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值