Neural Architecture Search with Reinforcement Learning论文详解

这一篇可以算是神经网络架构搜索(NAS)的开篇之作,是谷歌的Zoph等人提出来的,有时候狭义上讲NAS时就专指这篇文章。论文采用RNN结构的Controller从搜索空间中预测生成自网络,用强化学习方法去优化controller RNN的参数。

以视觉应用的CNN模型为例,该论文的基本思想有三点:

  1. 用controller RNN去预测每一个网络层(layer)的卷积参数;
  2. 将controller RNN的输出作为强化学习中的action作用于搜索空间,设计子网络;
  3. 在训练完子网络后,将数据验证集的精度作为强化学习的reward去训练controller RNN的参数

Controller RNN的结构图如下所示。
在这里插入图片描述
Controller RNN结构

RNN是递归循环网络,上一个序列的RNN输出和隐藏层状态会作为下一个序列RNN的输入状态。本论文的controller RNN是一个2层的LSTM(长短时记忆,RNN中的一种),每层有35个隐节点。

每一次输出都是卷积操作的其中一种参数(例如:卷积核长度、宽度、步长等等),从第一层卷积层开始预测,逐步迭代预测到最后一层。假设第一层的卷积操作有5种参数,网络共有10个卷积层,那么controller RNN总共要预测50个序列。论文在controller的输出上还添加了一个anchor point,用于预测从前面某一层layer来的跳跃连接(skip connection)。

在预测出所有卷积层的参数后,把这些参数用来构建子网络模型。将训练数据分成验证集和训练集两部分,训练集用于训练子网络模型,在训练结束后,子网络模型计算在验证集上的精度,并把精度值作为reward反馈给强化学习,强化学习进而去优化controller RNN的参数。

之所以会采用强化学习的方法去训练controller RNN的参数,是因为controller RNN的输出到子网络生成验证集的精度不是可导函数的计算过程,从验证集上获得的Loss无法通过梯度下降的方法传递到controller RNN的输出上。

由于每预测出一个子网络后,均需要在数据集上训练一段时间,为了加速controller RNN的训练过程,需要在集群上训练子网络和controller RNN。论文采用parameter server的方式将K个controller的参数共享,每个controller一次共生成m个子网络,每个GPU训练一个子网络模型。将m个子网络的精度作为minibatch的数据分布式训练controller参数。

controller RNN的parameter server分布式训练如下图所示。
在这里插入图片描述
controller RNN集群分布式训练

论文以Cifar-10数据集作为实验对象。
卷积层的参数搜索空间为:

  1. 卷积核宽高:[1, 3, 5, 7]
  2. 卷积核个数:[24, 36, 48, 64]
  3. 卷积步长:[1, 2, 3]

训练的细节包括:parameter server上共享100个controller参数;每个controller一次生成8个子网络模型,800个子网络模型共在800个GPU上同时训练;每个子网络训练50个epoch,用最后5个epoch的最大验证集精度来训练controller;为了加速controller的收敛,controller预测的layer个数由少变多。

训练的过程总共搜索了12800个子网络,用时28天,在cifar-10上的实验结果如下图所示。
在这里插入图片描述
NAS强化学习方法和其他SOTA方法在cifar-10上的性能比较

从上表中可以看出,随着搜索空间选择的增多,以及搜索网络层深度的增加,NAS的网络架构参数也在增大,同时分类误差率也在减小。在同一等级的模型参数量大小条件下,分类精度达到了state-of-the-art的水平。

论文中同样也在RNN网络上进行搜索,这里就不讲细节实现和实验结果了,有兴趣的可以自行查看论文。

作为第一篇NAS论文方法,就能达到SOTA的水平,可以看出NAS的研究潜力是非常巨大的。但是我们同时也应该注意到,NAS虽然节省了人工设计的成本,却使用了更多的GPU计算资源和更长的搜索时间。从经济角度考虑,这不是一个好现象,所以迫切需要更经济、更高效的方法来改进NAS方法。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值