Learning Transferable Architectures for Scalable Image Recognition
发布于CVPR 2018
Google Brain团队
是对NAS(也是Google Brain的)的改进,NAS方法在数据集较大时计算昂贵,所以作者提出在小数据集上生成模型架构,迁移到大数据集上
主要工作:
设计了搜索空间NASNet search space,便于后续迁移;
设计正则化ScheduledDropPath,提高了NASNet泛化能力(不同于DropOut)
论文工作
基本结构和NAS一致
insight:很多CNN都是重复结构,如ResNet,所以作者的motivation是不学习整个网络结构,而只学习那些关键的convolutional cell结构((such as the repeated modules present in the Inception and ResNet models [59, 20, 60, 58]). 控制器 RNN 有可能预测出以这些图案表示的通用卷积单元。这个单元可以串联起来,处理任意空间维度和滤波深度的输入。
作者提出学习两种cell:Normal Cell(不改变图像大小)以及Reduction Cell(长宽减半)
针对CIFAR10 IMageNet,作者设计了两种网络(基本的Cell是一样的,区别在于ImageNet的Redction Cell更多,因为其图像更大)
作者提到,这两种cell的结构可以是一样的,但是不一样的时候效果更好。
类比Inception,ResNet,作者将重复次数N和filter数量作为自由参数。(后续有实验对比)
在作者提出的搜索空间里,每个cell接收两个input:hi, hi-1也就是前两个较低层cell的输出,控制器RNN递归预测整个网络结构
Controller对每个cell的预测分为B个block,每个Block有5个预测步骤,对应于5个不同的softmax分类器。
具体步骤
step3 4选择有:
step5的组合,要么相加,要么沿着filter的维度连接在一起
新创建的隐藏状态会添加进隐藏状态集合,作为后续的输入(理解下面的例子)
作者实现上选择了B=5,让RNN预测Normal Cell和Reducton Cell,总共有2×5B个预测结果。
一种例子,便于理解:
作者对比了两种参数优化策略:强化学习(同NAS)和随机搜索(均匀分布中采样,不是从RNN的softmax分类器采样了)
实验结果
实验配置:500GPUs(google爹)
RNN参数优化使用Proximal Policy Optimization(PPO)
超参数:N,filter个数
整个搜索超过4天,比以前的方法(NAS)快了7×倍(28days)
得到三个最佳的结构NASNet-A,NASNet-B,NASNet-C
提出了ScheduledDropPath,单元格的每个路径的删除概率都随着训练次数线性增加(而dropout是一定几率删除,在本实验的效果不好)
In DropPath, each path in the cell is stochastically dropped with some fixed probability during training. In our modified version, ScheduledDropPath, each path in the cell is dropped out with a probability that is linearly increased over the course of training.
实验数据
- CIFAR10数据集上,2.4%错误率 SOTA
- ImageNet,82.7% top-1,96.3% top-5 SOTA
计算需求减少28%
- 小版本的NASNet(移动平台上)74% top-1,高于SOTA 3.1%
CIFAR10,set=4 / 6,NASNet-A with cutout的效果最好
"Cutout" 是一种数据增强技术,常用于深度学习中的图像分类任务。它的主要思想是在训练过程中随机选择图像上的一个矩形区域,并将该区域的像素值置为零(或其他预定义的数值),从而在输入图像中产生一个类似于孔洞的效果。
Cutout 的主要目标是增加模型对于部分遮挡或缺失信息的鲁棒性,减轻过拟合的问题。
ImageNet上,转移CIFAR10训练出的网络架构,但是权重重新训练
作者也测试了best cell在资源受限环境下的表现(比如移动设备),准确率74.0%,效果也超越SOTA
对比了准确率和计算需求、参数规模的关系,可以看到NASNet表现也较好
最后作者也对比了随机搜索RS和强化学习RL的效率,x轴是搜索过的网络架构的数量,可以看到,随着搜索次数的增长,强化学习的参数优化效果更好