【论文解析】《CVPR2020》AdversarialNAS: 以可微的方式搜索对抗生成网络

【论文解析】《CVPR2020》AdversarialNAS: 以可微的方式搜索对抗生成网络

极链AI云官网地址
点击注册(完成新人引导后即可领取体验券哦)

关注极链AI云公众号
学习更多深度学习相关知识

一、前言

本文是CVPR2020收录的NAS领域的论文,首次将可微搜索的方式引入到GAN网络结构的搜索中,不仅提高了搜索效率,而且取得了SOTA的效果。

论文链接:

https://arxiv.org/pdf/1912.02037.pdf

源码链接:

https://github.com/chengaopro/AdversarialNAS

二、简介

如图所示,由于GAN不稳定的特性,人工设计不仅需要丰富的经验,而且最重要的是,当前GAN的类型很少,主要可以分为两类:DCGAN-based和ResNet-based。使用图b)AutoGAN的方式搜索,虽然一定程度上可以缓解上述缺陷,但是还存在两个方面的问题:

  1. 虽然搜索可以看做是一个无监督的任务,我们期望一个有效的监督信号,用来指导整个搜索的过程。现有的基于NAS搜索的GAN都使用IS或FID评估网络结构的性能,并将IS或FID作为RNN控制器的reward。那么问题来了,计算IS或FID需要生成数百张图像,这个过程本身就是耗时的,使用稍微大点的数据集,所需的GPUdays就可以轻松破千,平民玩家只好望而却步。
  2. 搜索过程需要考虑生成器G和判别器D的平衡性,一旦不平衡,训练很可能就会崩掉。然而AutoGAN只搜索了生成器,判别器相当于是预设的,随着生成器的层次加深,使用预设的更深层的判别器,所以把结构定死的话,很大程度上会限制搜索到最优的结构。

为了解决上述问题,论文提出了AdversarialNAS,具体来说:首先,作者提出了一个 [公式] 量级的大搜索空间,并将搜索空间连续松弛,以便使用可微的方式搜索。其次,论文直接利用判别器D来评估生成器G的网络结构,也就是说,用判别器来监督生成器的搜索方向,使用梯度下降的方法来更新生成器的网络结构,因此这种方式消除了代理任务的额外计算成本,提高了效率。此外,为考虑判别器和生成器之间的平衡性,需要在搜索过程中动态改变判别器的结构,这时使用生成器评估判别器的结构,通过随机梯度上升的方式计算判别器损失的最小值。不难发现,搜索的过程本质就是一个对抗的机制,AdversialNAS就是使用这种对抗机制代替了代理任务,这么做不仅有利于判别器和生成器结构的平衡,更有利于寻找更优的模型。

三、解决的问题

当前有很多人工设计的GAN,在各类生成任务中也取得了出色的结果,但是由于GAN的不稳定性,导致其设计门槛比较高。尽管之前诞生了AutoGAN等NAS搜索的方法,但其本质是基于强化学习,所以在搜索效率上有待提升,且AutoGAN只搜索了生成器,而不能同时搜索判别器。论文提出了AdversarialNAS,将可微的搜索方式(具体请参考DARTS) 首次引入到GAN的搜索中,且可以同时搜索判别器和生成器。在搜索的过程中,不需要额外代理任务评估网络结构的性能,并且该搜索方法考虑了判别器和生成器结构的相关性,让它们之间保持平衡,从而训练不容易发生崩溃。最后,AdversialNAS仅用了1个GPU day就在CIFAR10上取得了10.87 FID / 8.74 IS的结果。

四、key points:

  • 首次提出基于梯度的NAS方法搜索GAN,具有更高的效率和更好的性能。设计了更大的搜索空间涵盖更多的结构,搜索时间降至1GPU days。
  • 针对GAN对抗的本质,设计了对抗搜索的策略,交替搜索判别器和生成器,提高二者之间的平衡性。

五、具体实现

搜索空间

首先,明确一下最终的目标:即搜索一系列Cells,包括上采样的Cell和下采样的Cell,并使用这些Cells搭建一个完整的GAN。我们使用3个上采样的Cell堆叠为生成器,使用4个下采样的cell作为鉴别器。卷积神经网络具有层次结构,每一层都有其独特的功能,因此正好可以用来组合具有不同结构的Cell。

如图:我们定义Cell为:一个具有N个节点的有向无环图,每个cell使用图片作为输入,输出处理后的图片。每个节点Xi表示一个中间特征,每两个节点Xi和Xj之间的边Fi,j表示一个特定操作。由于我们的最终的目标是寻找生成器的最优结构,所以为上采样的Up-Cell设计了一个几乎全联通的拓扑结构,其由4个节点(0是输入节点,本质上属于【上一层】的输出节点4)组成。后面的结点是由在其前面的结点和候选集中选择的操作两部分决定的,生成器的操作候选集如下:

  • None • Identity
  • Convolution 1x1, Dilation=1
  • Convolution 3x3, Dilation=1 • Convolution 3x3, Dilation=2
  • Convolution 5x5, Dilation=1 • Convolution 5x5, Dilation=2

其中,None表示两个cell之间没有操作,会改变网络的拓扑结构(说白了就是两个结点之间原有的连接断开了),Identity表示跳跃连接操作(为了融合多尺度特征),上述两个操作stride=1,分辨率不变。生成器的搜索空间还包含一个上采样操作的子集,其包含三种操作:

  • Transposed Convolution 3x3
  • Nearest Neighbor Interpolation
  • Bilinear Interpolation

需要说明的是,上述三个操作只能在Up-Cell的0 → 1和 0 → 2两处使用。为了用对抗的方式寻找生成器,我们还要搜索判别器,思路很简单,把生成器的结构倒过来即可。生成器的是操作候选集合和判别器相同,不过它们把上采样子集的3个操作换成了下采样的操作,具体如下:

  • Average Pooling • Max Pooling
  • Convolution 3x3, Dilation=1 • Convolution 3x3, Dilation=2
  • Convolution 5x5, Dilation=1 • Convolution 5x5, Dilation=2

注意,当stride=2时,只能在 2 → 4和 3 → 4上进行下采样操作,因此搜索空间涵盖了 [公式] 个不同的网络结构。

网络结构的连续松弛

我们每次搜索,实质上都是为每一条边选择候选集中的一个特定操作。因此,第n个Cell的某个结点Xj,可以由其前面的结点Xi、以及Xi和Xj之间的操作f共同决定,如下所示:

在f的选择上,如果是基于强化学习的NAS,就会直接从候选集里采样一个操作。受到基于梯度的NAS的启发,论文利用Gumbel-Max的技巧将f松弛到一个连续的搜索空间。

其中Of是Gumbel(0,1)分布采样的噪声,τ为softmax的温度temperature。p代表Xi和Xj相连的边上使用操作f的概率,采用Gumbel-softmax更严格地遵循学习到的概率分布。因此,每条边相当于包含一个概率向量,即:

其中,α是一个可学习的参数。因此,我们就可以将搜索网络结构的目标转化为学习每条边的最优概率向量集,并根据所学的概率分布推导出网络结构。此外,为了动态地同时改变判别器的结构,我们还引入了一组连续参数β来计算判别器中各函数的概率:

因此,判别器的可以像生成器的表示方法一样:

总的来说,论文提出的AdversialNAS以可微的方式学习一组连续参数α和β,并通过简单地保留搜索空间中最有可能的操作来获得最终的生成器架构。我们将网络中的所有操作的搜索空间用Super- G和Super-D表示。

如果你不了解DARTS,上面的描述可能会比较抽象。我们举个例子:假设我们搜索空间有6个操作(操作0~操作5),如果我们简单地从这6个操作里随机选择一个,这个方式显然是离散的,以前的基于强化学习(RL)的搜索方法就是这么干的,然后基于RL的方法还需要对每个网络结构重新评估验证,费时费力。那么基于梯度的方法,也就是本文采用的方法,首先使用一个Gumble-softmax策略(softmax的魔改版),把6个操作的概率归一化,把这六个操作的概率用一个向量表示,假设为[0.1, 0.15, 0.25, 0.4, 0.03, 0.07],这里就是将搜索空间松弛的过程。最后我们选择操作的时候,只需要挑出来向量中概率最大的,也就是0.4对应的算子,至此,我们就选到了我们想要的操作算子。对比之前的RL搜索方式,这里省去了为每个操作验证效果的步骤,从而大幅度减少了搜索时长。

对抗网络结构的搜索

首先回顾一下DARTS的更新过程:

Lval和Ltrain分别代表验证集和训练集的损失函数,算法的目标是计算和最小化验证集损失Lval,来发现最佳的网络结构α(可以理解为要搜索的参数),其中w*(可以理解为网络本身的参数)是我们计算并最小化训练集损失Ltrain得到的。我们通过梯度下降的方法对w和α进行优化。

分类任务有明确的指标(如交叉熵损失函数)可以监督搜索过程,但是针对GAN的训练而言,就没有这种明确的指标。之前的AutoGAN和AGAN使用IS来评估网络结构的性能,并且通过RL策略更新,计算IS需要生成数百个图像,而且需要离线推理,这是非常耗时的。因此论文让G和D二者相互竞争,提高两者的性能,也就是利用D引导G搜索,反之亦然。作者还是从GAN本身对抗的特性出发,将优化的过程定义为G和D的竞争策略,用V(α,β)表示,这里要保证G和D的权值都是当前最优的。

效仿DARTS的两层优化过程,其中Pdata表示数据的真实分布,Pz表示生成数据的先验分布,WD(β)表示在结构β下判别器的最优权重。WG(α)表示在α结构下生成器的最优权重。通过对比DARTS的思路,以笔者的理解,上述算法的意思是:我们要搜索网络结构α和β,首先要利用当前最优的网络结构参数WD和WG,最大化判别器判断真实数据为真的概率,最小化生成器生成的假图被判别器识别出的概率。我们用到的最优网络结构怎么来的呢?就是在当前判别器的网络结构β和生成器的网络结构α下,针对判别器D:最大化真实数据通过判别器的概率、最大化判别器识别出生成器输出的假图的概率,针对生成器G:最小化判别器判断出生成器生成的假图的概率。

介绍完搜索的过程,我们再来说一下GAN的训练损失函数,如上图所示,对于特定网络结构{α, β}的两个最佳权重{W∗G(α),W∗D(β)},可以通过WG和WD之间的另一个极小极大博弈得到。不过如果直接这么干的话,操作是费时的,在DARTS中是把训练过程近似为:∇α Lval (w∗(α), α)≈∇α Lval (w−ξ∇wLtrain (w,α),α)。受该方法的启发,对于给定的结构{α,β}相应的权值{W∗G(α),W∗D(β)}可以通过像普通GAN一样的一个step的对抗性训练得到。

不过,两种网络结构之间的博弈也可以用另一种方法进行搜索:具体来说,根据Goodfellow已证明的方法,对于给定的判别器,当前最优的生成器可以通过一个step的对抗训练获得。所以论文提出的AdversialNAS算法如下图所示,每一次迭代都可以通过提升或者降低梯度来获得最优网络结构的权重:

具体来说:首先,我们用k个step去更新判别器的结构+权重,具体来说,我们采用2m个噪声和2m个真实数据,然后用随机梯度上升的方法分别更新网络结构和权重。然后我们再采样2m个噪声,使用随机梯度下降的方法分别更新生成器的网络结构和权重。需要注意的是,无论是D还是G,都需要先更新网络结构,这就需要保证更新的网络结构的权值需要是当前最优的。

六、搜索模型的表现

搜索模型的结构

上表是模型在CIFAR10上搜索出的生成器最佳结构,可以发现:生成器倾向于选择双线性上采样的方法,可能是因为双线性上采样能够获得比最近邻上采样更精细的特征。整个结构没有膨胀卷积的出现,表明对于CIFAR这类低分辨率的图像,简单堆叠普通的卷积已经可以满足生成器的感受野。此外,随着Cell层次的加深,None操作也变得更多,也就意味着融合多尺度特征的操作变得更频繁。

从最后的评价指标来看,AdversialNAS取得了SOTA的FID评分,在IS评价指标上也仅次于Progressive GAN的表现。在搜索的计算资源消耗上,在RTX2080Ti上也仅需要1个GPU day,相比于AutoGAN和AGAN的搜索速度有了明显提升。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值