原论文:arxiv
研究问题:Catastrophic Forgetting
文章提出了一种“新”的持续学习方法。构造了一个结合特征生成重放和特征蒸馏的模型。
文中作者论证了其模型的高效计算能力以及易于扩展到大的数据集上的优点。
文章创新点
- 使用生成特征而不是生成数据(图像)的形式来对过去学习过的旧数据进行回放。
其原因是抽象出来的图像特征分布比图像的像素分布简单的多。 - 得益于使用生成特征的方法,存储的数据量远远小于使用生成数据的方法,节省了大量空间。
模型网络构成
- 特征提取器,包含特征蒸馏器
- 分类器,包含特征生成重放
Figures & Tables
模型训练图
F
t
F_t
Ft是特征提取器,
H
t
H_t
Ht是分类器。
模型训练算法
模型在训练时,首先第一轮初始化
F
1
,
H
1
,
G
1
F_1,H_1,G_1
F1,H1,G1。利用真实数据D 来训练F和H,将F的结果作为G的输入。
因为G是用来生成特征图而不是图像数据。
训练F:用真实数据分布D来训练特征提取器F,提取图像的特征
训练H:接收F产生的特征图,以及G生成的特征图,以此来进行训练。训练目标是判断输入的特征图属于哪一类;并能够区分该特征图是F提取的还是G生成的。
训练G:利用F产生的特征图和上一次迭代的G产生的特征图进行训练。训练目标是最小化与F产生的特征图的loss;并且记住上一次迭代中已学到的知识。
Loss Functions
L D t W G A N ( X t ) = + E z ∼ p z , c ∈ C t [ D t ( c , G t ( c , z ) ) ] − E u ∼ D t [ D t ( c , F t ( x ) ) ] L G t W G A N ( X t ) = − E z ∼ p z , c ∈ C t [ D t ( c , G t ( c , z ) ) ] \begin{aligned} \mathcal{L}_{D_{t}}^{\mathrm{WGAN}}\left(\mathcal{X}_{t}\right)= & +\mathbb{E}_{\mathbf{z} \sim p_{z}, c \in C_{t}}\left[D_{t}\left(c, G_{t}(c, \mathbf{z})\right)\right] -\mathbb{E}_{\mathbf{u} \sim \mathcal{D}_{t}}\left[D_{t}\left(c, F_{t}(\mathbf{x})\right)\right] \\ \mathcal{L}_{G_{t}}^{\mathrm{WGAN}}\left(\mathcal{X}_{t}\right)= & -\mathbb{E}_{\mathbf{z} \sim p_{z}, c \in C_{t}}\left[D_{t}\left(c, G_{t}(c, \mathbf{z})\right)\right] \\ \end{aligned} LDtWGAN(Xt)=LGtWGAN(Xt)=+Ez∼pz,c∈Ct[Dt(c,Gt(c,z))]−Eu∼Dt[Dt(c,Ft(x))]−Ez∼pz,c∈Ct[Dt(c,Gt(c,z))]
Replay alignment loss
L G t R A = Σ j = 1 t − 1 Σ c ∈ C j E z ∼ p z [ ∥ G t ( c , z ) − G t − 1 ( c , z ) ∥ 2 2 ] . \mathcal{L}_{G_{t}}^{\mathrm{RA}}=\Sigma_{j=1}^{t-1} \Sigma_{c \in C_{j}} \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}}\left[\left\|G_{t}(c, \mathbf{z})-G_{t-1}(c, \mathbf{z})\right\|_{2}^{2}\right] . LGtRA=Σj=1t−1Σc∈CjEz∼pz[∥Gt(c,z)−Gt−1(c,z)∥22].
实验部分
Result on ImageNet-Subset
由上图可以看出,作者提出的模型在ImageNet上效果非常好,比reblance 和作者自己的另一个模型(采用Gaussian replay)更好。作者的模型在训练过程中,不需要存储任何过去任务的样本,而且能够动态的结合当前任务的真实数据由G生成特征图。
Results on CIFAR-100
由上图可以看出,在CIFAR-100上,作者的模型相比Rebalance而言,差距已然缩小了许多,在25-tasks设置的训练中,甚至比Rebanlace差一点;而相比作者的另一个模型(采用Gaussian replay)差距仍然明显。
作者解释说在低分辨率的图像上提取的特征不够好,可能是ImageNet数据的分辨率比CIFAR-100的要高。