前言
今天介绍的这篇论文和前面我看过的少样本学习方面论文的侧重点不一样,这篇 Few-Shot Lifelong Learning 于2021年3月1号发布在arxiv上,一作来自印度伊利诺伊理工学院。
额,一个糟糕的…(╯‵□′)╯炸弹!•••*~●
这篇文章的研究重点在于少样本情况下的增量学习,论文题目中的"Lifelong"即表明主题,这在日常生活生产中是更为实用的方法。文章提出的算法较为简单,但是可以取得很好的效果,本文将简单介绍一下少样本增量学习与文章中的算法。
一、少样本增量学习概念
首先了解一下增量学习的概念。
百度百科中描述增量学习是指一个学习系统能不断地从新样本中学习新的知识,并能保存大部分以前已经学习到的知识。
简而言之,就是可以不停接受新知识,同时尽量不遗忘之前学过的旧知识。像极了人生啊(⊙ˍ⊙)
增量学习在人工智能的实际应用中很有意义,比如有一台可以打开柜门自主取货的贩卖机,机器可以依据训练好的分类网络,来识别买家从贩卖机中取走了哪些商品,进而计算出收款额,这样的贩卖机在现实生活中应用广泛。
现在就出现一个问题,如果我想更换贩卖机中的商品,那就需要之前的分类网络能识别出新的商品才能继续正常使用。之前训练好的分类网络只学会了区分旧商品的知识,并不能区分新的商品,这时就需要网络继续学习。并且管理者会希望网络学会区分新商品的同时,仍然能够区分旧商品,以防贩卖机中的商品又发生更换。
此时网络的继续学习就属于增量学习,它可以有两种方法。第一种是最简单的,即使用新商品与旧商品的样本从头共同训练分类网络,这种方式简单粗暴,缺点显而易见。当又有新的商品需要加入时,就需要再从头训练一遍网络,因此这种方法灵活性和效率很低。
第二种方法就是需要仔细设计算法,以求达到只使用新商品的样本,在现有网络的基础上进一步训练,使网络在能区分新商品的同时不遗忘旧商品的效果。这种方法的优势在于不用保存旧商品的样本,每次加入新商品时,只使用新商品的数据训练即可。大大提高了灵活性和效率,减少存储。
Few-Shot Lifelong Learning中提出的增量学习算法就属于上述的第二种方法,这也是增量学习的研究方向。
在理解了增量学习后,少样本增量学习的概念就容易懂了。
少样本增量就是在增量的基础上加入了少样本的需求,即要学习的新类别的有标签样本数很少。 这也非常符合实际需要,仍以上面的贩卖机为例,我们肯定希望用少量的新商品样本就可以完成分类网络的更新。
二、少样本增量学习问题描述
在第一部分中介绍了少样本增量学习的概念,在本节中将对其给出具体带有符号的问题描述,方便后续的算法介绍。(问题描述来自 Few-Shot Lifelong Learning)
首先定义一系列的有标签训练数据集,
其中,
D
(
1
)
D^{(1)}
D(1)表示基类数据集,即最先用于训练网络的数据,相当于网络获得的旧知识,
D
(
2
)
,
D
(
3
)
.
.
.
.
D^{(2)},D^{(3)}....
D(2),D(3)....表示每一次增量学习中的训练集,相当于网络要学习的新知识,
D
(
2
)
D^{(2)}
D(2)就是第一次增量学习中的训练集。其中
x
j
x_j
xj表示图像,
y
j
y_j
yj表示对应的类别标签。
同时用
L
(
t
)
L^{(t)}
L(t)表示第
t
t
t个数据集中的类别集合,并规定这一系列数据集中的类别互不交叉,即
这时将少样本的要求加入增量学习中,即要求
D
(
t
>
1
)
D^{(t>1)}
D(t>1)中的数据来自
C
C
C个类别,每个类别只有
K
K
K个样本,把这样的设置称为C-way K-shot。
论文算法需要解决的问题就是在网络能够分类 D ( 1 ) D^{(1)} D(1)的基础上,依次增加 D ( 2 ) , D ( 3 ) . . . . D^{(2)},D^{(3)}.... D(2),D(3)....数据使网络完成增量学习,希望网络能够记住旧知识并学会新知识。到最后一轮增量学习完成后,网络应该能够区分来自 D ( 1 ) , D ( 2 ) , D ( 3 ) . . . . D^{(1)},D^{(2)},D^{(3)}.... D(1),D(2),D(3)....数据集中的全部类别的图像。
三、论文算法介绍
论文给出的算法思想很简单,下面先简单介绍一下大体的算法结构:
首先使用基类数据集 D ( 1 ) D^{(1)} D(1)训练出一个分类网络,这个网络由特征提取器 Θ F \Theta_F ΘF和分类器 Θ C \Theta_C ΘC组成,这样可以使网络获得基础知识(旧知识)。训练完成后,只保留下 Θ F \Theta_F ΘF。
之后,依次使用 D ( 2 ) , D ( 3 ) . . . . D^{(2)},D^{(3)}.... D(2),D(3)....训练网络 Θ F \Theta_F ΘF,在每次增量训练时,指定网络中的一小部分参数可训练,其他参数则固定,(注意!这里是算法核心哦) 如此可以较好解决由于训练数据集过小导致的过拟合问题。
在完成最后一次增量训练后,网络就学习了这一系列数据集中的全部类别知识,可以对属于这些类别的测试图片进行类别预测。以上就是算法的整体结构。
测试图像分类方法:
在基础训练或后续的每个增量训练完成后,利用每一个类(包括旧类和新类)的训练图像特征的均值获得该类的类特征,并采用类似Prototypical Net中距离度量的方法对测试图像进行分类。
获得第
c
c
c类类特征的方式如下:
接下来具体看一下基础训练和每次的增量训练是怎么完成的。
1. 基础训练
基础训练是完成对基类数据的分类训练工作,采用基本的交叉熵损失对网络 Θ F \Theta_F ΘF和 Θ C \Theta_C ΘC进行训练即可。同时为了提升训练效果,作者还尝试加入了自监督学习的并行任务。这里作者采用的是判断图像旋转较度的自监督学习方法,即在交叉熵损失中增加了一个旋转角度分类损失。
训练完成后只保留 Θ F \Theta_F ΘF,删去 Θ C \Theta_C ΘC。
交叉熵损失表示:
2. 增量训练
在增量训练时,为了适应少样本的训练集和满足保留旧知识的要求,作者选择指定 Θ F \Theta_F ΘF中小部分的网络参数参与训练,其余参数保持不变。值得注意的是,第 t − 1 t-1 t−1次增量训练使用的数据集只有 D ( t ) D^{(t)} D(t),因此该算法不需要保存之前的数据集。
选择可训练参数
选择可训练参数时,先设置一个阈值,选择出绝对值小于阈值的网络参数作为可训练参数。 这样选择出来的网络参数由于自身的值较小,所以认为它们对旧知识的保持不起决定性作用,因此改变它们可能会使旧知识最大程度地保留下来。作者在实验中设置阈值使可训练的参数占总参数的10%。
直观的训练过程可以参考下图,图中绿色的圆点表示可训练的网络参数,其余的圆点表示固定的参数:
损失函数
训练时损失函数则由三部分构成,由于这时已经没有分类器 Θ C \Theta_C ΘC了,因此不使用交叉熵损失。此时的损失集中于提取到的图像特征上,包括三个方面:
-
triplet loss:该损失用来保证属于同一类的图像特征距离近,而属于不同类的图像特征距离远。
-
正则化损失:该项损失用来保证更新后的网络参数与原来的值相距很近,以此来减少网络对于旧知识的遗忘。
-
余弦相似度损失:该项损失使得新类的类特征远离旧类的类特征,用来保证网络区分类别的准确度。
最后对这三项损失加权求和获得总损失,用该损失来训练 Θ F \Theta_F ΘF中的可训练参数。如此便完成了一次增量训练。
四、实验结果
论文在miniImageNet、CIFAR-100和CUB200 三个少样本常用数据集上进行了增量学习的实验,下面只放上在miniImageNet的实验对比结果。
FSLL即为文章提出的算法,作者设置每次的增量学习中采用5-way 5-shot 的设置,即每个增量数据集中包含5类、每类5张训练样本。因此,miniImageNet数据集(一共100类样本,被作者分为60个基类和40个新类)可以分为9个Session,数据集中的基类数据在第一个Session中,新类数据平均分布在后面的8个Session中。
由实验结果可以看到,FSLL算法取得了最佳的效果;同时由于Session数越大,需要分类的类别数就越多,因此分类准确率会出现下降,但FSLL算法已经最大可能地保留了旧知识。
总结
少样本增量学习也是目前研究的热点,我认为它相比于传统的少样本学习而言具有更高的应用价值。
这篇论文也是我看的第一篇有关增量学习的论文,整体而言本论文中算法的思路比较清晰直接,并且也取得了较好的效果。
之后如果再看到有意思的算法还会继续分享的~