论文地址:https://arxiv.org/abs/2012.04324
源码地址:尚未开源
1 Core Idea
序列化地用不同性质的样本来训练模型,会导致灾难性遗忘(catastrophic forgetting)。作者提出域泛化(domain randomization),运用图像处理,随机化当前域的数据分布。并以这个处理结果为基础,设计一个元学习(meta-learning)策略,有一个正则化项惩罚所有和模型从当前域到不同“辅助”元域的迁移有关的损失(penalizes any loss associated with transfer
ring the model from the current domain to different “auxiliary” meta-domains),同时使得模型更容易地适应新域。这些元域由随机图像处理生成。
在D1上训练的模型,在D2上测试时遗忘的程度与D2和D1的相似性有关,但我们无法控制这种相似度。受域随机化《Domain randomization for transferring deep neural networks from simulation to the real world》《Domain randomization and pyramid consistency: Simulation-to-real generalization without accessing target domain data》及单源域泛化《Model vulnerability to distributional shifts over image tranformation sets》启发,作者提出严重扰乱当前域的数据分布,使得下一个域的样本更有可能接近当前域的数据分布,从而让域自适应的过程更加light。本文使用图像转换(image transformations)来实现随机化的过程。
但又会一个疑问:是否会学到对抗迁移到新域的表征(whether we can learn representations that are inherently robust against transfer to new domains),也就是对抗在与当前域分布不同的样本上的梯度更新。作者借助元学习,并提出一个正则化策略,强制令模型在当前域中感兴趣的任务训练,同时在新域更新参数时具有弹性。总的来说,元学习需要多个不同的元任务(元域meta- domains),但在这一背景下,只有当前域的样本。为了解决这个问题,作者提出“辅助”元域,在起始分布的基础上随机化而得,在这里用到了标准图像转换。
2 Notations and problem formulation
假设要训练一个模型,依靠服从分布
的数据,解决一个任务
。实际上,我们是不知道这个分布的详细信息,但已知有一些样本
,我们聚焦于监督学习,并假设有m个训练样本
。常规情况下会用期望风险损失empirical risk minimization (ERM)来优化
,
但在fine-tune到新任务时会遗忘起始任务,也即。
假设有一个含有N个域的序列,每个域都服从分布
,并假设下一个域
(样本
)到来时,样本
就丢失了(becomes unavailable)
3 Methods
3.1 Image transformation sets
本文的核心部分之一是图像转换集合的集合,用来驱动域随机化的过程,形成“辅助元域”。假设有一个集合,里面的每一个元素都是一种特定程度的转换(例如:亮度提高10%)。并通过在这个集合抽取一个转换动作
,用到每一个已知的数据点上,得到
。这个集合里面包含了色彩/几何变换以及噪声插入的不同组合,从而生成随机的辅助元域。
Domain randomization 域随机化可以帮助缓解灾难性遗忘的问题,在把带标注样本喂进去当前模型
并优化前,用
得到
3.2 A meta-learning algorithm
训练的目标包括三部分:1)学习感兴趣的任务(注意这个和上面的
不同);2)当迁移到不同域时缓解灾难性遗忘;3)更容易适应到新域。
针对后两个问题,借助MAML的思想,但MAML认为有大量的元域数据,并可以用它们去更新元梯度,令与起始域和元域相关的损失保持在较低水准,从而解决后两个问题。
但是,当处理域时,没有办法获得其他域
的信息,因此无法使用它们作为元域。但作者提出用标准图像转换法来会生成这些元域数据
。
Optimization problem 在训练过程中,在每一次有关当前域的梯度下降之前,作者在给定的(生成的)元域数据上做了任意次数的优化。例如,如果在第t代有K个不同的元域,对于每个元域运行单独运行一次梯度下降,得到权重空间中K个不同的点,。
作者的核心思想就是用这些权重组态(weight configuration)来计算自适应过程后的关于起始域(由已知训练集
得来的)的损失
,最小化这些损失意味着更不容易产生遗忘。这些损失的和定义为
。
进一步作者计算关于元域样本的损失,并将这些损失的和定义为
。如果可以将样本从元域中分成元训练集和元测试集,最小化
相当于运行MAML算法。
将上述损失集合在一起:
在上面的阐述中,作者假设每一个辅助元域只会有一个单独的元优化步骤(a single meta-optimization step)。所以,计算梯度涉及到计算一个梯度的梯度
。实际上,作者想要在训练的过程中得到元域
(边训练边生成?)。在实现中,作者令K=1,并通过随机抽取一个不同的辅助转换,实现梯度下降。算法伪代码如下: