DARTS+: Improved Differentiable Architecture Search with Early Stopping

DARTS+:改进的可微分架构搜索与早停机制

在这里插入图片描述

论文链接:https://arxiv.org/abs/1909.06035

项目链接:没找到

Abstract

最近,人们对神经架构设计过程的自动化越来越感兴趣,而可微分架构搜索(DARTS)方法使该过程在几个GPU天内可用。然而,当搜索epoch数变大时,通常会观察到DARTS的性能崩溃。同时,在所选的体系结构中发现了大量的“跳跃连接”。在本文中,我们认为崩溃的原因是在优化中存在过拟合。因此,我们提出了一种简单有效的算法,命名为“DARTS+”,通过在满足一定条件时“早停”搜索过程来避免崩溃并改进原有的DARTS。我们还在基准数据集和不同的搜索空间上进行了全面的实验,证明了我们的算法的有效性,在CIFAR10上的测试误差为2.32%,在CIFAR100上的测试误差为14.87%,在ImageNet上的测试误差为23.7%。我们进一步指出,通过手动设置少量搜索epoch,“早停”的想法隐式地包含在一些现有的DARTS变体中,而我们给出了“早停”的显式标准。

1 Introduction

网络结构搜索(Neural Architecture Search, NAS)在自动化机器学习(Automated Machine Learning, AutoML)中扮演着重要的角色,近年来备受关注[3,8,20 - 22,26,35,42]。可微分体系结构搜索(Differentiable Architecture Search, DARTS)[21]由于能够在快速搜索的同时达到理想的性能而受到广泛关注。特别是对连续参数的建筑搜索空间进行编码,形成one-shot模型,并通过基于梯度的双级优化训练one-shot模型进行搜索。

尽管DARTS的效率很高,但人们发现了DARTS的一个关键问题[3,4,37,41]。也就是说,经过一定的搜索周期后,所选架构中的跳跃连接数量会急剧增加,从而导致性能下降。我们把经过一定次数的epoch后性能下降的现象称为DARTS的“崩溃”。为了解决这个问题,P-DARTS[3]等一些作品设计了搜索空间正则化,以减轻搜索过程中跳跃连接的主导地位。然而,这些方法涉及更多的超参数,需要由人类专家仔细调整。此外,Single-Path NAS[30]、StacNAS[17]和SNAS[34]在DARTS中使用一级优化而不是二级优化,即架构参数和模型权重同时更新。然而,这些算法的搜索空间需要精心设计[17,21,34]。综上所述,DARTS崩溃的机制仍然是开放的。

在本文中,我们首先证明了DARTS的崩溃是由于搜索阶段的过拟合,这导致训练误差和验证误差之间有很大的差距。特别是,我们解释了为什么过拟合会导致DARTS中所选架构中出现大量的跳跃连接,从而损害所选架构的性能。为了避免DARTS崩溃,我们添加了一个简单有效的“早停”范式,称为“DARTS+”,其中搜索过程根据特定标准停止,如图1 (a)所示。指出在“提前停止”的情况下,搜索已经饱和。我们注意到DARTS的一些进展,包括P-DARTS [3], Auto-DeepLab[19]和PC-DARTS[36],也隐含地采用了早停的想法,在他们的方法中手动设置较少的搜索时代。

在这里插入图片描述

此外,我们进行了足够的实验来证明所提出的DARTS+算法的有效性。具体来说,DARTS+可以在包括DARTS、MobileNetV2、ResNet在内的各种空间上成功搜索。在DARTS搜索空间中,DARTS+在CIFAR10上的测试误差为2.32%,在CIFAR100上的测试误差为14.87%,而搜索时间小于0.4 GPU-Day。在转换到ImageNet时,如果引入SE-Module [12], DARTS+的Top-1误差达到了23.7%,达到了惊人的22.5%。DARTS+也可以直接在ImageNet上搜索,并获得23.9%的Top-1错误。

综上所述,我们的主要贡献如下:

  • 我们研究了DARTS的崩溃问题,并指出其根本原因是在DARTS训练中模型权值的过拟合。
  • 我们为DARTS引入了一种有效的“早停”范式,以避免崩溃,并提出了有效的和自适应的早停标准。
  • 我们在基准数据集和各种搜索空间上进行了广泛的实验,以证明所提出算法的有效性,该算法在所有这些数据集上都取得了最先进的结果。

2 Collapse of DARTS

DARTS有一个不希望出现的行为[21],即当搜索epoch的数量很大时,所选架构中往往会出现太多的跳跃连接,从而使性能变差。本文将这种性能下降现象称为DARTS的“崩溃”现象。在本节中,我们首先对原始的DARTS进行快速回顾,然后指出DARTS的崩溃问题并讨论其根本原因。

2.1 DARTS

DARTS的目标是搜索一个单元,这些单元可以堆叠形成卷积网络或循环网络。每个单元是 N N N个节点的有向无环图(DAG) { x i } i = 0 N − 1 \{x_{i}\}_{i=0}^{N-1} {xi}i=0N1,其中每个节点代表一个网络层。我们将操作空间记为 O \mathcal{O} O,并且每个元素都是一个候选操作,例如,0,skip-connect, convolution, max-pool等。DAG的每条边 ( i , j ) (i,j) (i,j)表示从节点 x i x_i xi到节点 x j x_j xj的信息流,该信息流由体系结构参数 α ( i , j ) α^{(i,j)} α(i,j)加权的候选操作组成。特别地,每条边 ( i , j ) (i,j) (i,j)可以用一个函数 o ˉ ( i , j ) \bar{o}^{(i,j)} oˉ(i,j)来表示,其中 o ˉ ( i , j ) ( x i ) = ∑ o ∈ O p o ( i , j ) ⋅ o ( x i ) \bar{o}^{(i,j)}(x_{i})=\sum_{o\in\mathcal{O}}p_{o}^{(i,j)}\cdot o(x_{i}) oˉ(i,j)(xi)=oOpo(i,j)o(xi);各操作 o ∈ O o \in \mathcal{O} oO的权值是体系结构参数 α ( i , j ) α ^{(i,j)} α(i,j)的软最大值,即 p o ( i , j ) = exp ⁡ ( α o ( i , j ) ) ∑ o ′ ∈ O exp ⁡ ( α o ′ ( i , j ) ) p_{o}^{(i,j)}=\frac{\exp(\alpha_{o}^{(i,j)})}{\sum_{o^{\prime}\in\mathcal{O}}\exp(\alpha_{o^{\prime}}^{(i,j)})} po(i,j)=oOexp(αo(i,j))exp(αo(i,j))。中间节点为 x j = ∑ i < j o ˉ ( i , j ) ( x i ) x_{j}=\sum_{i<j}\bar{o}^{(i,j)}(x_{i}) xj=i<joˉ(i,j)(xi),输出节点 x N − 1 x_{N−1} xN1为除输入节点外所有中间节点的深度连接。上述超网络称为one-shot模型,我们用 w w w表示超网络的权值。

对于搜索过程,我们将 L t r a i n \mathcal{L}_{train} Ltrain L v a l \mathcal{L}_{val} Lval分别表示为训练损失和验证损失。然后通过以下双层优化问题学习体系结构参数:
min ⁡ α L v a l ( w ∗ ( α ) , α ) , s . t . w ∗ ( α ) = arg ⁡ min ⁡ w L t r a i n ( w , α ) . (1) \begin{array}{rl}\min_{\alpha}&\mathcal{L}_{val}(w^*(\alpha),\alpha),\\\mathrm{s.t.}&w^*(\alpha)=\arg\min_{w}\mathcal{L}_{train}(w,\alpha).\end{array} \tag{1} minαs.t.Lval(w(α),α),w(α)=argminwLtrain(w,α).(1)
在得到体系结构参数 α α α后,得到最终的离散体系结构:1)设置 o ( i , j ) = arg ⁡ max ⁡ o ∈ O , o ≠ z e r o p o ( i , j ) o^{ (i,j)} = \arg\max_{o\in\mathcal{O},o\neq zero}p_{o}^{(i,j)} o(i,j)=argmaxoO,o=zeropo(i,j); 2)对于每个中间节点,选取 max ⁡ o ∈ O , o ≠ z e r o p o ( i , j ) \operatorname*{max}_{o\in\mathcal{O},o\neq zero}p_{o}^{(i,j)} maxoO,o=zeropo(i,j)值最大的两条入线。更多的技术细节可以在原始的DARTS论文[21]中找到。

2.2 崩溃问题

在DARTS中观察到,所选择的体系结构中涉及到大量的跳跃连接,这使得体系结构很浅,性能很差。作为一个例子,让我们考虑在CIFAR100上进行搜索。图2(i)(c)中绿线所示的跳跃连接的 α α α值随着搜索历元数的增加而变大,因此所选架构中跳跃连接的数量增加,如图2(i)(a)中绿线所示。这种浅层网络的可学习参数比深层网络少,因此表达能力较弱。因此,具有大量跳跃连接的架构性能较差,即崩溃,如图2(i)(a)中的蓝线所示,更多实验见附录C。

在这里插入图片描述

为了更直观,我们在图1(b)中的CIFAR100上绘制了来自不同搜索时期的选定架构。当搜索次数增加时,所选架构中的跳跃连接数量也会增加。这种现象也可以在其他数据集上观察到,例如CIFAR10和Tiny-ImageNet-200。

在这里插入图片描述

为了避免崩溃,有人可能会建议调整搜索超参数,例如1)调整学习率,2)改变训练和验证数据的部分,以及3)在跳跃连接(如dropout)上添加正则化。不幸的是,这些方法只是在某些搜索时期缓解了崩溃,但崩溃最终会出现,这意味着超参数的选择并不是崩溃的根本原因。

2.3 过拟合与分析

为了弄清楚DARTS的崩溃,我们观察到在搜索过程中,one-shot模型中的模型权值会出现“过拟合”。在双层优化中,模型权值 w w w由训练数据更新,结构参数 α α α由验证数据更新。由于模型的权重,one-shot模型中的 w w w被过度参数化,w倾向于很好地拟合训练数据,而验证数据由于 α α α的数量有限而处于欠拟合状态。具体来说,在CIFAR10/100数据集中,训练准确率可以达到99%,而CIFAR10的验证准确率只有88%,CIFAR100的验证准确率只有60%。这意味着“过拟合”,因为训练误差和验证误差之间的差距很大。

结果表明,模型权值的过拟合是导致模型崩溃的主要原因。特别是在初始状态下,模型权值与训练数据欠拟合,训练误差与验证误差之间的差距很小。因此,体系结构参数 α α α和模型权值 w w w可以更好地结合在一起。经过一定的搜索周期后,模型的权重会过拟合训练数据。然而,在验证数据上,它们的拟合效果不如训练数据,并且模型的第一个单元可以比最后一个单元获得相对更好的低级特征表示。

如果我们允许不同的单元在one-shot模型中具有不同的架构,那么最后的单元更有可能选择更多的跳跃连接,从而直接从第一个单元获得良好的特征表示。图2(ii)显示了我们在不同阶段搜索不同架构时学到的不同层的正常单元架构。可以看出,该算法倾向于在第一个单元中选择具有可学习操作的深度架构(图第2(ii)(a)段),而在最后一个单元中优先选择具有许多跳跃连接的架构(图2(ii)(c)段)。

在这里插入图片描述

这与前面的分析一致,最后一层将选择跳跃连接。如果像DARTS那样强迫不同的单元具有相同的体系结构,则持续的搜索和拟合将把跳跃连接从最后一个单元广播到第一个单元,使所选体系结构中的跳跃连接数量逐渐增加。我们可以看到模型权值的过拟合导致所选模型体系结构的退化。

这种过拟合现象可以进一步说明为如图3所示的综合二值分类问题。one-shot模型为2层网络,定义为 o ( x ) = w r ⊤ ( α 0 x + ( 1 − α 0 ) W x ) o(\mathbf{x})=\mathbf{w}_{r}^{\top}(\alpha_{0}\mathbf{x}+(1-\alpha_{0})\mathbf{W}\mathbf{x}) o(x)=wr(α0x+(1α0)Wx),其中 W , w r \mathbf{W},\mathbf{w}_{r} W,wr为模型权值,其中 ∥ w r ∥ = r \|\mathbf{w} _r\|= r wr=r α 0 \alpha_0 α0为结构参数(如图3(a)所示)。用于架构搜索的训练数据 T \mathcal{T} T 和验证数据 ν \nu ν 是二维特征表示,它们是高斯分布的混合,使得 T = \mathcal{T}= T= { ( x i , y i ) , y i x i   ∼   N ( μ t e , σ t 2 I ) } ,   V   =   { ( x i , y i ) , y i x i   ∼ \{(\mathbf{x}_i,y_i),y_i\mathbf{x}_i\:\sim\:N(\mu_t\mathbf{e},\sigma_t^2\mathbf{I})\},\:\mathcal{V}\:=\:\{(\mathbf{x}_i,y_i),y_i\mathbf{x}_i\:\sim {(xi,yi),yixiN(μte,σt2I)},V={(xi,yi),yixi N ( μ v e , σ v 2 I ) } N(\mu_v\mathbf{e},\sigma_v^2\mathbf{I})\} N(μve,σv2I)},其中 e = 1 2 ( 1 , 1 ) ⊤ \mathbf{e}=\frac1{\sqrt2}(1,1)^\top e=2 1(1,1)。训练标签和验证标签都是平衡的,使得标签1的数据数量与标签−1的数据数量相同。如果用DARTS搜索one-shot模型,训练和验证损失为 L t r a i n = ∑ ( x i , y i ) ∈ T l ( o ( x i ) , y i ) \mathcal{L}_{train}=\sum_{(\mathbf{x}_i,y_i)\in\mathcal{T}}l(o(\mathbf{x}_i),y_i) Ltrain=(xi,yi)Tl(o(xi),yi) L v a l = ∑ ( x i , y i ) ∈ V l ( o ( x i ) , y i ) , l ( o , y ) \mathcal{L}_{val}=\sum_{(\mathbf{x}_{i},y_{i})\in\mathcal{V}}l(o(\mathbf{x}_{i}),y_{i}),l(o,y) Lval=(xi,yi)Vl(o(xi),yi),l(o,y) = log ⁡ ( 1 + exp ⁡ ( − y o ) ) =\log(1+\exp(-yo)) =log(1+exp(yo)),则在一定条件下,将选择跳跃连接,总结如下引理。

引理1 考虑用一个二元分类问题进行搜索,其中数据和一次性模型如上所示(如图3所示)。假设(1)特征表示被归一化,使得 1 2 μ t 2 + σ t 2 = 1 2 μ v 2 + σ ^ v 2 = 1 \frac{1}{2}\mu_{t}^{2}+\sigma_{t}^{2}=\frac{1}{2}\mu_{v}^{2}+\hat{\sigma}_{v}^{2}=1 21μt2+σt2=21μv2+σ^v2=1,所述全连接(fc)层 { W x } \{W_x\} {Wx}的输出按上述归一化处理;(2)定义上述损失并进行双级优化训练。然后我们有:

P1:如果 σ t σ_t σt很小,对于任意 α 0 α_0 α0 w r → r e , W → W ∗ , \mathbf{w}_{r}\to\mathbf{re},\mathbf{W}\to\mathbf{W}^{*}, wrre,WW,,其中 W ∗ x = t x e \mathbf{W}^{*}\mathbf{x}=t_{\mathbf{x}}\mathbf{e} Wx=txe,和 t x = 2 2 + μ t 2 e ⊤ x t_{\mathbf{x}}=\frac{2}{\sqrt{2+\mu_{t}^{2}}}\mathbf{e}^{\top}\mathbf{x} tx=2+μt2 2ex

P2:若 σ t σ_t σt较小, σ v > σ 0 ( r ) σ_v > σ_0(r) σv>σ0(r),且 σ 0 ( r ) σ_0(r) σ0(r)是单调递减函数,则 d L v a l d α 0 < 0 \frac{\mathrm{d}\mathcal{L}_{\boldsymbol{val}}}{\mathrm{d}\alpha_{0}}<0 dα0dLval<0,说明 α 0 α_0 α0随梯度下降而增大。

在这里插入图片描述

证明见附录A。在本文中,引理1中讨论的特征表示对应于one-shot模型中的最后一层特征表示。具体来说,在搜索过程开始时, σ v σ_v σv较大, r r r较小;当发生过拟合时, r r r变大,使训练数据更易于分离,而 σ v σ_v σv保持较大。在搜索过程中, σ t σ_t σt趋于较小。根据引理1,在搜索开始时, σ v σ_v σv较大, r r r较小,因此 σ v < σ 0 ( r ) σ_v < σ_0(r) σv<σ0(r),则优先选择可学习的操作(图3(b)左图)。当发生过拟合时, σ v σ_v σv保持较大, σ v > σ 0 ( r ) σ_v > σ_0(r) σv>σ0(r) r r r越大, σ 0 ( r ) σ_0(r) σ0(r)越小,则倾向于选择跳跃连接(图3(b)右图)。

3 The Early Stopping Methodology

由于DARTS的崩溃问题是由2.2节指出的双级优化中一次模型的“过拟合”引起的,因此我们提出了一种简单有效的基于DARTS的“早停”范式来避免崩溃。特别是,当DARTS开始崩溃时,搜索过程应该尽早停止在自适应准则处。与最初的DARTS相比,这种范式可以带来更好的性能和更少的搜索成本。我们用DARTS+表示具有我们的早停准则的DARTS算法。

我们要强调,早停是必要的,应该给予更多的注意。研究发现,重要的连接是在训练的早期阶段确定的[1]。重要的和相应的变化发生在训练的最初阶段[7]。

除了“过拟合”问题外,“早停”的另一个动机是操作的体系结构参数 α α α的排序很重要,因为在所选的体系结构中只选择具有最大 α α α的操作。在搜索过程中,验证数据对可学习操作有不同的偏好,这与 α α α值的排序相对应。如果 α α α的秩不稳定,则结构噪声太大而无法选择;而当其趋于稳定时,最终选择的体系结构中可学习的操作不变,我们可以将该点视为饱和搜索点。图2(i)(a-b)中的红色圆圈表示不同数据集上的饱和搜索点。它验证了在此点之后,所有数据集(蓝线)上所选架构的验证精度趋于降低,即崩溃。综上所述,搜索过程可以在饱和搜索点“早停”,以选择所需的架构并避免过拟合,我们强调这一点并不意味着一次性模型的收敛。

我们首先遵循DARTS使用的基于单元的架构。第一个标准说明如下。

准则1:当一个正常单元中存在两个或两个以上的跳跃连接时,搜索过程停止。

该停止准则的主要优点是简单。与其他DARTS变体相比,DARTS+只需要在DARTS基础上进行少量修改,就能以更少的搜索时间显著提高性能。由于太多的跳跃连接会损害DARTS的性能,而适当数量的跳跃连接有助于将信息从第一层传递到最后一层,并稳定训练过程,例如ResNet[10],这使得体系结构获得更好的性能。因此,停留在标准1是一个合理的选择。

标准1中的超参数2是由P-DARTS[3]激发的,其中最终架构单元中的跳跃连接数被手动减少到2。然而,在处理跳跃连接方面,DARTS +与P-DARTS有本质上的不同。P-DARTS在搜索过程中不干预跳跃连接的数量,只是在搜索过程结束后用其他操作替换冗余的跳跃连接作为后处理。

相反,我们的DARTS+最终得到了理想的体系结构,具有适当数量的跳跃连接,以避免DARTS崩溃。它可以更直接、更有效地控制跳跃连接的数量(参见表1中的DARTS+和P-DARTS之间的性能比较)。

由于可学习操作的结构参数 α α α的稳定排序表明在DARTS中搜索过程是饱和的,我们也可以使用以下停止准则:

准则2:当可学习操作的架构参数 α α α的排序在一定的epoch数(如10个epoch)内趋于稳定时,搜索过程停止。

从图2(i)中可以看出,当准则1成立时,饱和训练点(准则2的停止点)接近停止点(图2(i)(a)中的红色虚线)。我们还注意到,这两个准则都可以自由使用,因为停止点很接近。然而,标准1更容易操作,但如果需要更精确地停止或涉及其他搜索空间,则可以使用标准2。准则2中的10个epoch是一个超参数,根据我们的实验,当6个epoch以上操作符的排名保持不变时,可以认为是稳定的,这意味着这个超参数不敏感,可以灵活选择。我们进一步指出,我们的早停范式解决了DARTS的一个内在问题,并且与其他技巧正交,因此它有可能用于其他基于DARTS的算法,以获得更好的性能。此外,与计算验证损失中的Hessian特征值等其他方法相比,我们的方法非常容易进行补充[39]。

我们注意到,最近最先进的可微分架构搜索方法也以一种特别的方式引入了早停的想法。为了避免崩溃,P-DARTS[3]使用1)搜索25个epoch而不是50个epoch,2)在跳跃连接后采用dropout,3)手动将跳跃连接的数量减少到2个。Auto-DeepLab[19]使用更少的epoch搜索架构参数,发现搜索更多的epoch并没有带来好处。PC-DARTS[36]使用部分通道连接来减少搜索时间,因此需要更多的epoch来收敛搜索。因此,设置50个训练周期也是一个隐含的早停范例,见附录C

4 Experiments and Analysis

4.1 数据集

在本节中,我们在基准分类数据集上进行了大量实验,以评估所提出的DARTS+算法的有效性。我们使用了四个流行的数据集,包括CIFAR10 [15], CIFAR100 [15], TinyImageNet-2005和ImageNet[6]。CIFAR10/100由50K训练图像和10K测试图像组成,分辨率为32 × 32。Tiny-ImageNet-200包含100K的64 × 64的训练图像和10K的测试图像。ImageNet来源于ILSVRC2012[28],其中包含120多万张训练图像和50K张验证图像。我们遵循ImageNet数据集的一般设置,其中图像被调整为224 × 224用于训练和测试。

4.2 早停机制在不同搜索空间上的有效性

为了验证在DARTS+中提前停止的有效性,我们在不同时期选择不同的架构上使用不同的数据集进行了广泛的实验。实验分体系结构搜索和体系结构评价两个阶段进行。

DARTS搜索空间。在DARTS搜索空间中,实验设置与DARTS相似。对于CIFAR10和CIFAR100,在架构搜索阶段,我们使用了与原始DARTS相同的one-shot模型,超参数与DARTS几乎相同,只是最多使用了60个epoch。在体系结构评估阶段,实验设置遵循原始DARTS,只是使用了2000个epoch以更好地收敛。对于Tiny-ImageNet-200,在搜索阶段,one-shot模型与CIFAR10/100几乎相同,只是在第一层上增加了一个stride为2的3 ×3卷积层,将输入分辨率从64×64降低到32×32。其他设置与CIFAR10/100中使用的设置相同。

我们使用标准1和标准2在DARTS搜索空间中早停。实验设置的其他细节可参考附录B

所选架构在不同时期的分类结果如图2(i)所示。我们还将两个标准下的“早停”点分别标记为“红色虚线”和“红色圆圈”。我们观察到,所选择的体系结构在更大的epoch下表现更差,这意味着原始DARTS存在崩溃问题。相反,无论数据集的类型如何,“早停”都可以在两个停止标准下生成良好的体系结构。

我们还比较了表1和图2(i)中的“早停”标准1和标准2。我们观察到这两个标准在所有数据集上都达到了相当的性能,因为停止点非常接近。

在这里插入图片描述

MobileNetV2和ResNet搜索空间。为了进一步验证DARTS+的有效性,我们使用MobileNetV2[29]和ResNet[10]作为backbone构建架构空间[2]。对于MobileNetV2搜索空间,我们引入了一组具有不同核大小和扩展比的mobile inverted bottleneck卷积(MBConv)来构建搜索块。对于ResNet搜索空间,我们通过用一组候选操作替换残差块来构建一次性模型,其中我们将跳跃连接保留在残差块中,并涉及10个候选操作。在这两个搜索空间中,softmax应用于体系结构参数来计算权重,用于确定所选的体系结构。实验在CIFAR100数据集上进行。关于架构搜索和架构评价的搜索空间和实验设置的详细信息汇总在附录B中。

由于跳跃连接不涉及搜索空间,我们使用“早停”标准2。所选架构在不同时期的分类结果如图4所示。标准2的“早停”时间用“红圈”表示。可以看出,与随机搜索的架构(epoch 0)和大epoch相比,“早停”选择的架构获得了相对最好的性能。

在这里插入图片描述

4.3 与先进水平的比较

除非特别说明,否则我们使用DARTS搜索空间和“提前停止”标准1评估DARTS+。注意,标准1和标准2的停止点在建议的搜索空间中几乎是相同的,如第4.2节所讨论的。

对于CIFAR10、CIFAR100和Tiny-Imagenet-200数据集,架构搜索和架构评估阶段的实验设置见4.2节和附录B。对于ImageNet,如下[36],one-shot模型从3个3 × 3的卷积层开始,步幅为2,将分辨率从224 × 224降低到28 × 28,其余网络由8个单元组成。我们从训练集中选择10%的数据用于更新模型权重,另外10%用于更新架构参数。在架构评估阶段,为了更好的收敛,我们训练了800个epoch, batch size为2048。其他实验设置与DARTS基本相同,见附录B

搜索结果和分析。所提出的DARTS+采用了“早停”的方法,减少了搜索时间。对于CIFAR10,使用单个Tesla V100 GPU的搜索过程需要0.4个GPU-Day,并在35个epoch 35左右停止。对于CIFAR100,搜索时间为0.2 GPU——Day,搜索过程在18个epoch 左右停止。对于TinyImageNet-200,搜索在10个epoch左右停止。对于ImageNet,搜索过程涉及200个epoch,在Tesla P100 GPU上需要6.8个GPU-Day。

从图2(i)中所示的跳跃连接数量可以看出,DARTS+搜索到的单元格中包含了少量的跳跃连接,说明在CIFAR10/100、Tiny-ImageNet-200、ImageNet三个数据集上,DARTS+都能成功搜索到。然而,由于所选择的体系结构中充满了跳跃连接,原有的DARTS无法在CIFAR100上进行搜索,并且之前大多数关于可微搜索的工作[3,21,34]都没有在ImageNet上进行搜索。

所选择的体系结构可以在附录B中找到。

基于CIFAR10和CIFAR100的架构评价。评价结果汇总如表1所示。对于从CIFAR10或CIFAR100中选择的每个单元格,我们报告两个数据集上的性能。使用简单的“提前停止”范式,我们在CIFAR10上的测试误差为2.32%,在CIFAR100上的测试误差为14.87%。所提出的DARTS+比其他改进的DARTS算法(如P-DARTS和PCDARTS)更简单、更好。ProxylessNAS使用不同的搜索空间,并且需要更多的搜索时间。此外,DARTS+比其他修改过的dart变体(包括ASAP)更容易实现。

我们进一步将初始通道数从36增加到50,并添加AutoAugment[5]和mixup[40]等增强技巧,以获得更好的效果。表1显示,DARTS+在CIFAR10和CIFAR100上的测试误差分别达到了令人印象深刻的1.68%和13.03%,这说明了DARTS+的有效性。

基于Tiny-ImageNet-200的架构评估。对于DARTS和DARTS+,我们使用直接从Tiny-ImageNet-200中搜索的架构进行评估。为了进行公平的比较,我们还从其他算法中转移了搜索到的架构。结果如表2所示。在标准1和标准2中,DARTS+达到了最先进的28.3%测试误差和27.6%测试误差。请注意,使用DARTS在Tiny-ImageNet-200上搜索的体系结构参数大小更小,性能也差得多,因为DARTS容易崩溃,而且使用DARTS搜索的体系结构包含大量的跳跃连接。

在这里插入图片描述

基于ImageNet的架构评估。我们使用直接从ImageNet中搜索的架构进行评估,并使用来自CIFAR100的架构来测试所选架构的可移植性。实验结果如表3所示。请注意,我们重新实现了PC-DARTS并报告了结果。在ImageNet上使用所提出的DARTS+进行搜索时,所选择的架构达到了令人印象深刻的23.9%/7.4%的Top-1/Top-5误差,而从CIFAR100转移来的架构达到了最先进的23.7%/7.2%的误差。结果表明,具有“早停”的DARTS在有限的时间内成功地在大规模数据集上搜索出良好的体系结构,并具有令人印象深刻的性能。

在这里插入图片描述

我们还在CIFAR100的架构中采用SE模块[12],并引入AutoAugment[5]和mixup[40]进行训练,以获得更好的模型。结果如表3所示,我们得到了22.5%/6.4%的Top-1/Top-5误差,只有额外的3M的FLOPS,显示了所选架构的有效性。

5 Conclusion

在本文中,我们进行了全面的分析和大量的实验,证明了DARTS存在崩溃问题,这主要是由于DARTS中一次模型的过拟合造成的。我们提出了“DARTS+”,其中引入了“早停”范式,以避免DARTS崩溃。实验表明,我们在有限的GPU时间内成功地搜索了包括大规模ImageNet在内的各种基准数据集,并且得到的架构在所有基准数据集上都达到了最先进的性能。此外,建议的“早停”准则可适用于不同的搜索空间,而许多最近有进展迹象的DARTS搜索网络可以使用“早停”来取得更好的结果。

Appendix

A. 引理1的证明

一堆公式就没放上来了。

图5为 α 0 = 0.5 α_0 = 0.5 α0=0.5 σ 0 ( r ) σ_0(r) σ0(r)的数值。事实上, σ 0 ( r ) σ_0(r) σ0(r) α 0 α_0 α0并不敏感。

在这里插入图片描述

B. 基本实验设置

B.1 架构搜索

DARTS搜索空间:8种候选操作,包括skip-connect、max-pool-3x3、avg-pool3x3、sep- conv3x3、sep- conv5x5、di - conv3x3、di -conv5x5、zero

所选择的体系结构如图6所示。我们观察到,由DARTS+搜索的但他包含大多数卷积和一些跳跃连接。

在这里插入图片描述

MobileNetV2搜索空间:我们遵循[2],以MobileNetV2[29]为骨干构建架构空间。我们通过一组具有不同核尺寸{3、5、7}和扩展比例{3、6}的mobile inverted bottleneck卷积(MBConv)层来初始化one-shot模型的每一层。由于输入图像大小仅为32×32,因此主干与原始MobielNetV2略有不同,如表4所示。

在这里插入图片描述

采用早停准则2所选择的体系结构如图7所示。

在这里插入图片描述

ResNet搜索空间:10个候选操作,包括zero、max-pool-3x3、avg-pool-3x3、sep-conv-3x3、sep-conv5x5、sep- conv7x7、di - conv3x3、conv3x3、conv5x5、conv7x7。

采用早停准则2所选择的体系结构如图8所示。

在这里插入图片描述

B.2 架构评估

DARTS搜索空间:表1、表2、表3

MobileNetV2和ResNet搜索空间:图4

C. 附加实验

C.1 关于DARTS崩溃的更多插图

图9、10、11

C.2 PC-DARTS的隐式早停

结构如图12所示。
在这里插入图片描述

"Diffable Architecture Search" (DARTS) 是一种自动机器学习架构搜索算法,其代码通常基于PyTorch或TensorFlow等深度学习框架编写。如果你想运行其中的代码,需要按照以下步骤进行: 1. **安装必要的库**:首先确保你的环境中已安装Python、PyTorch和相关科学计算库,如torchvision和numpy。如果要用到DARTS的官方实现,可能还需要安装fairseq(因为它是DARTS的一部分)。 2. **克隆代码库**:从GitHub或其他官方仓库(如作者的个人页面)克隆DARTS的源码。 ```bash git clone https://github.com/mlperf/models/tree/master/darts ``` 3. **设置环境**:有些DARTS模型可能需要特定版本的库。根据项目的readme文件或requirements.txt文件配置虚拟环境。 4. **理解代码结构**:研究代码结构,了解主程序(如train.py或search.py)、数据加载模块、模型定义以及实验配置。 5. **预处理数据**:如果需要,对数据集进行预处理或下载合适的预训练数据。 6. **配置和修改**:根据实验需求调整超参数,比如学习率、优化器、网络结构等。 7. **开始训练或搜索**:运行`python train.py` 或 `python search.py`,这可能会涉及到GPU资源,确保你的设备支持。 8. **监控和日志**:跟踪训练过程中的损失和其他指标,保存中间结果和最终模型。 9. **分析结果**:训练完成后,评估生成的模型性能并进行可视化分析。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值