文章目录
前言
我将看过的增量学习论文建了一个github库,方便各位阅读地址
主要工作
论文提出了一种算法,以解决增量学习中的灾难性遗忘问题,与iCaRL将特征提取器的学习与分类器分开不同,本论文提出的算法通过引入新定义的loss以及finetuning过程,在有效抵抗灾难性遗忘的前提下,允许特征提取器与分类器同时学习。
本论文提出的方法需要 e x a m p l a r examplar examplar
算法介绍
总体流程
总体分为四个流程
- 构建训练数据
- 模型训练
- finetuning
- 管理 e x a m p l a r examplar examplar
步骤一:构建训练数据
训练数据由新类别数据与examplar构成。
设有 n n n个旧类别, m m m个新类别,每个训练数据都有两个标签,第 i i i个训练数据的标签为
- 使用onehot编码的 1 ∗ ( m + n ) 1*(m+n) 1∗(m+n)的向量 p i p_i pi
- 设旧模型为 F t − 1 F_{t-1} Ft−1,训练数据为 x x x, q i = F t − 1 ( x ) q_i=F_{t-1}(x) qi=Ft−1(x), q i q_i qi为一个 1 ∗ n 1*n 1∗n维的向量
步骤二:模型训练
模型可以选用常见的CNN网络,例如ResNet32等,按照国际惯例,这一节会介绍distillation loss,作为一篇被顶会接收的论文,自然不能免俗
loss函数介绍
符号约定
符号名 | 含义 |
---|---|
N N N | 有 N N N个训练数据 |
p i p_i pi | 含义查看上一节 |
q i q_i qi | 含义查看上一节 |
q ^ i \hat q_i q^i | 新模型旧类别分支的输出,为一个 1 ∗ n 1*n 1∗n的向量 |
n n n | 旧类别分支 |
m m m | 新类别分支 |
o i o_i oi | 新模型对于第 i i i个数据的输出,为一个 ( n + m ) ∗ 1 (n+m)*1 (n+m)∗1的向量 |
Classification loss即交叉熵,如下:
L C ( w ) = − 1 N ∑ i = 1 N ∑ j = 1 n + m p i j ∗ s o f t m a x ( o i j ) L_C(w)=-\frac{1}{N}\sum_{i=1}^N\sum_{j=1}^{n+m}p_{ij}*softmax(o_{ij}) LC(w)=−N1i=1∑Nj=1∑n+mpij∗softmax(oij)
其中
s
o
f
t
m
a
x
(
o
i
j
)
=
e
o
i
j
∑
j
=
1
n
+
m
e
o
i
j
softmax(o_{ij})=\frac{e^{o_{ij}}}{\sum_{j=1}^{n+m}e^{o_{ij}}}
softmax(oij)=∑j=1n+meoijeoij
distillation loss的形式如下
L D ( w ) = − 1 N ∑ i = 1 N ∑ j = 1 n p d i s t i j log q d i s t i j L_D(w)=-\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{n}pdist_{ij}\log qdist_{ij} LD(w)=−N1i=1∑Nj=1∑npdistijlogqdistij
其中
p
d
i
s
t
i
j
=
e
q
^
i
j
t
∑
j
=
1
n
e
q
^
i
j
t
q
d
i
s
t
i
j
=
e
q
i
j
t
∑
j
=
1
n
e
q
i
j
t
pdist_{ij}=\frac{e^{\frac{\hat q_{ij}}{t}}}{\sum_{j=1}^{n}e^{\frac{\hat q_{ij}}{t}}}\\ qdist_{ij}=\frac{e^{\frac{q_{ij}}{t}}}{\sum_{j=1}^{n}e^{\frac{q_{ij}}{t}}}
pdistij=∑j=1netq^ijetq^ijqdistij=∑j=1netqijetqij
L D ( w ) L_D(w) LD(w)即让模型尽可能的记住旧类别的输出分布。t是一个超参数,在本论文中, t = 2 t=2 t=2
个人疑问
distillation loss的作用是让模型记住以往学习到的规律,相当于侧面引入了旧数据集,从而抵抗类别遗忘。
直觉上来说,distillation loss应该只对旧类别数据进行计算,但是新类别数据的旧类别分支输出仍用于计算distillation loss,论文对此给出的解释是“To reinforce the old knowledge”
我认为这种做法的出发点为:旧模型对于新类别数据的输出(经softmax处理),也是一种旧知识,也需要防止遗忘,因此,新模型对于新类别数据的旧类别输出(经softmax处理),与旧模型对于新类别数据的输出(经softmax处理)也要尽可能一致
步骤三:finetuning
使用herding selection算法,从新类别数据中抽取部分数据,构成与旧类别examplar大小相等的数据集,此时各类别数据之间类别平衡,利用该数据集,在小学习率下对模型进行微调,选用的loss函数应该是交叉熵。
步骤二使用类别不平衡的数据训练模型,会导致分类器出现分类偏好,finetuning可以在一定程度上矫正分类器的分类偏好
步骤四:管理 e x a m p l a r examplar examplar
论文给出了两类方法
- Fixed number of samples:没有内存上限,每个类别都有 M M M个数据
- Fixed memory size:内存上限为 K K K
使用herding selection算法选择新类别数据,构成新类别的 e x a m p l a r examplar examplar
实验
论文训练模型使用了数据增强,具体方式如下:
每个实验都进行了五次训练,取平均准确率
实验过程没有太多有趣的地方,在此不做过多说明
Fixed memory size
Fixed number of samples
在CIFAR100上的结果如下
img/cls表示每个examplar中图片的个数
Ablation studies
首先是选择数据构建examplar的方法,论文比对了三类方法
- herding selection:平均准确率63.6%
- random selection:平均准确率63.1%
- histogram selection:平均准确率59.1%
上述三个选择方法的解释如下:
接下来论文比对了算法各部分对准确率提升的贡献
上述模型的解释如下
个人理解
类别不平衡会导致灾难性遗忘,模型在学习旧类别时,所使用的数据是充分的,引入知识蒸馏loss,就是尽可能保留旧数据上学习到的规律,在训练时,相当于侧面引入了旧数据。
论文在distillation loss的基础上又引入了类别平衡条件下的finetuning,相当于进一步抵抗增量学习下类别不平衡的导致的分类器偏好问题,由此取得模型性能的提升。