Adversarial AutoAugment
摘要
- 自动学习的增强策略替代了以往的手工设计;
- AutoAugment耗时耗力,本文希望用对抗方法Adversarial AutoAugment,同时优化目标函数和增强策略搜索损失;
- 增强策略网络希望通过生成对抗增强策略来提高目标函数的损失,目标网络通过挖掘困难样本,学习更鲁棒的特征;
- 和AutoAugment相比,本文重复使用目标网络训练中的计算进行策略评估,省去对目标网络的再训练;
- 和AutoAugment相比,在计算损耗上减少了12x,在时间上节省了11x。并在CIFAR-10上实现top-1 test error 1.36%(19年sota)。
近期工作
- Smart Augmentation: 融合某一类两个或以上样本来改善目标网络的鲁棒性;结果显示,增强网络可以和目标网络同训练;
- AutoAugmentation:用RNN作为以恶样本控制器,去寻找数据集上最好的增强策略,为了减少计算,增强策略搜索是在代理任务上执行的;
- Population based augmentation(PBA):用一个曾倩策略的动态调度代替了固定的增强策略;
- Population based training(PBT)
方法
一个augmentation policy包含5个sub-policies,每个sub-policy包含两个image operations顺序执行,每个operation包含两个对应的参数:执行概率和强度。最终,最好的5个policy会合成25个sub-policy。对于mini-batch中的某张图片,仅有一个随机选择的sub-policy会被执行。
相较于AutoAugment,本文仅使用ShearX/Y, Translate X/Y, Rotate, AutoContrast, Invert, Equalize, Solarize, Posterize, Contrast, Color, Brightness, Sharpness, Cutout和Sample Pairing共计16个operations,每个operations强度离散的分为10个值。因此搜索空间为(16 x 10)^10,这里10次方代表1个policy x 5 sub-policies x 2 operations = 10。
本文将策略搜索问题转化为最大最小问题。其中最小化问题是指,最小化目标网络的损失函数:
对于每个输入样本,都会被M个增强策略为M个不同实例,对于增强策略m,损失函数为:
最大化问题,则是寻找更难的增强策略:
增强网络和AutoAugment相似,本文,RNN控制器要预测一个完整的策略包括20个离散变量。
最大的问题是:数据增强操作的不可导阻碍了目标网络F的梯度流向策略网络A。因此,考虑用REINFORCE算法:
其中,pm表示策略m的概率。为了减少θ梯度的变化,本文将一个mini-batch的损失Lm替换为Lm^(moving average,长度固定为一个epoch),然后在M个实例上normalize它,得到Lm~。
算法整体框图如下:
策略更新可视化如下: