这是CVPR 2019的论文,在我看来,一直到这篇文章,才算是对增量学习中一个基本问题进行了研究,那就是对于基于神经网络的增量学习而言,所谓的“灾难性遗忘”到底遗忘了啥?在前面几篇文章的分析中,作者大多都是给了一个较为笼统的解释,即遗忘了基于旧样本数据训练学到的模型知识,但这个知识如何表述,基本上是从蒸馏损失的角度出发来分析。在LwM这篇文章中,作者从网络得到的注意力区域图出发,重新定义了增量学习需要学习的知识,即增量学习不能遗忘,或者不能变化的,是注意力区域图。从这个角度出发,作者提出了Learning without Memorizing(LwM)算法。
一、动机
作者认为理想的增量学习系统需满足三方面要求:(1)模型可以学习流数据中的新类别,同时保存基于原始数据训练得到的模型知识;(2)模型需在所有新旧类别中表现良好;(3)内存消耗不应当随着新类别新样本的增多而增加。在作者看来,经典的LwF方法大体满足这三个要求,而至于iCaRL、end-to-end这两种方法,实际上需要使用部分旧类别数据,被作者吐槽不能称之为纯粹意义上的增量学习。在论文中,作者也只是重点与LwF方法进行性能上的比较。
二、思路
首先,作者给出如下定义:
Teacher model:Mt-1,仅基于base classes训练的模型
Student model:Mt,在Teacher model基础上,增量学习新类别数据得到的模型。
IPP(Information Preserving Penalty):约束Teacher model与Student model之间差异性的损失函数。
增量学习的任务为,在训练好Teacher model Mt-1的基础上,基于新数据训练新模型Mt,Mt与Mt-1应当具有这样的关系:在初始化(直接用Mt-1的参数初始化Mt)时,二者差异不大,但随着训练的继续,二者的差异,也就是IPP逐渐增加。一个良好的增量学习系统应当保证在使得IPP最小化基础上,损失函数最小,也就是保证对新数据和旧数据的分类性能达到一个平衡点。
对于LwF模型而言,IPP就是蒸馏损失函数LD,蒸馏损失函数的作用是保留模型Mt与Mt-1的预测结果差异。LD表现并不如预期的原因在哪里呢?在前面的几篇论文中我们可以看出,不同的作者理解不一样,iCaRL和end-to-end两篇论文主要是认为需要旧样本的参与,Large Scale IL这篇论文认为问题出在输出层上,因而提出偏置矫正的方法。在本文作者看来,LwF模型之所以表现不佳,是因为旧模型学习的知识到底是什么这一根本性问题,并没有被理解清楚,这也导致以上各类改进方法没有达到预期,甚至改的不算真正意义上的增量学习范畴。那作者认为什么是旧模型学习的知识呢?
答案是注意力图(attention map)。
attention map是一种对卷积神经网络可视化的手段,揭示了图像中的哪些点对于目标的识别具有重要影响。比如对类别A1,可能是图像中的区域B1得到的,而类别A2,可能是图像中的区域B2得到的。而常规的交叉蒸馏损失函数,其设计仅仅是在输出层做处理,保留模型的输出信息,而忽略了卷积过程中的attention map作用,因而影响了增量学习的性能。
为此,作者设计了如下的基于常规知识蒸馏和注意力知识蒸馏相结合的增量训练框架。其主体框架与LwF一致,也设计了常规的基于新类别数据的分类损失函数Classification Loss Lc,针对旧类别数据的知识蒸馏损失函数Knowledge Distillation Loss LD。除此之外,作者单独设计了基于旧类别数据建立的注意力蒸馏损失 Attention Distillation Loss LAD。如下图所示:
与LwF对比:
可以看出,如果忽略LwF中的正则化项R,则LwM比LwF就是多了一个针对注意力蒸馏损失的约束项。作者做了一个小实验,来说明为啥要加这个约束。
如上图所示,第一排图像中,在经过n个step增量训练之后,从attention map可以看出,在第n步增量训练后,map中的注意力区域已经从电话的仪表盘处移到了下方,而知识蒸馏损失为0.09,与第1步损失0.08接近,并没改变,反而是注意力蒸馏损失从0.12变成了0.82,变化较大。第二排图像中,在经过n个step增量训练之后,从attention map可以看出,在第n步增量训练后,map中的注意力区域并没有变化,一直在仪表盘上,知识蒸馏损失和注意力蒸馏损失变化均不大。这两个实验说明,相比于知识蒸馏损失函数,注意力蒸馏损失是一种更能代表模型本质知识的损失函数,需要加以约束。
理解的总体思路,下面解决最后两个问题:(1)如何计算attention map;(2)如何增加attention distillation loss
(1)如何计算attention map?前人有不少相关工作,作者选择了应用最广的Grad-CAM方法[1]。这篇文章很有意思,改天单独介绍一下。
(2)如何计算attention distillation loss?
这里Qt-1和Qt分别表示teacher model和student model中attention map向量化后的结果,i表示输入图像,c表示类别。至于为何采用l1范数而不是l2范数,纯粹是因为l1实验出来的效果更好。
三、实验
从图像上看,LwM基本达到了预期结果,注意力图的焦点基本被保持住。定量来看,有些数据集上的效果甚至超越了iCaRL这种利用了部分原始数据的方法,当然,作者这里也继续吐槽了一下iCaRL这种利用旧类别数据的方法,与它比较本身就对LwM的不公平。LwM even outperforms iCaRL on the iILSVRC-small dataset given that iCaRL has the unfair advantage of accessing the base-class data.
[1]R. R. Selvaraju, M. Cogswell, A. Das, R. Vedantam, D. Parikh, D. Batra, et al. Grad-CAM: Visual explanations from deep networks via gradient-based localization. In ICCV, pages 618–626, 2017.