深度学习论文笔记(增量学习)——End-to-End Incremental Learning

前言

我将看过的增量学习论文建了一个github库,方便各位阅读地址

主要工作

论文提出了一种算法,以解决增量学习中的灾难性遗忘问题,与iCaRL将特征提取器的学习与分类器分开不同,本论文提出的算法通过引入新定义的loss以及finetuning过程,在有效抵抗灾难性遗忘的前提下,允许特征提取器与分类器同时学习。

本论文提出的方法需要 e x a m p l a r examplar examplar


算法介绍


总体流程

在这里插入图片描述
总体分为四个流程

  1. 构建训练数据
  2. 模型训练
  3. finetuning
  4. 管理 e x a m p l a r examplar examplar

步骤一:构建训练数据

训练数据由新类别数据与examplar构成。

设有 n n n个旧类别, m m m个新类别,每个训练数据都有两个标签,第 i i i个训练数据的标签为

  1. 使用onehot编码的 1 ∗ ( m + n ) 1*(m+n) 1(m+n)的向量 p i p_i pi
  2. 设旧模型为 F t − 1 F_{t-1} Ft1,训练数据为 x x x q i = F t − 1 ( x ) q_i=F_{t-1}(x) qi=Ft1(x) q i q_i qi为一个 1 ∗ n 1*n 1n维的向量

步骤二:模型训练

模型可以选用常见的CNN网络,例如ResNet32等,按照国际惯例,这一节会介绍distillation loss,作为一篇被顶会接收的论文,自然不能免俗


loss函数介绍

符号约定

符号名含义
N N N N N N个训练数据
p i p_i pi含义查看上一节
q i q_i qi含义查看上一节
q ^ i \hat q_i q^i新模型旧类别分支的输出,为一个 1 ∗ n 1*n 1n的向量
n n n旧类别分支
m m m新类别分支
o i o_i oi新模型对于第 i i i个数据的输出,为一个 ( n + m ) ∗ 1 (n+m)*1 (n+m)1的向量

Classification loss即交叉熵,如下:

L C ( w ) = − 1 N ∑ i = 1 N ∑ j = 1 n + m p i j ∗ s o f t m a x ( o i j ) L_C(w)=-\frac{1}{N}\sum_{i=1}^N\sum_{j=1}^{n+m}p_{ij}*softmax(o_{ij}) LC(w)=N1i=1Nj=1n+mpijsoftmax(oij)

其中
s o f t m a x ( o i j ) = e o i j ∑ j = 1 n + m e o i j softmax(o_{ij})=\frac{e^{o_{ij}}}{\sum_{j=1}^{n+m}e^{o_{ij}}} softmax(oij)=j=1n+meoijeoij


distillation loss的形式如下

L D ( w ) = − 1 N ∑ i = 1 N ∑ j = 1 n p d i s t i j log ⁡ q d i s t i j L_D(w)=-\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{n}pdist_{ij}\log qdist_{ij} LD(w)=N1i=1Nj=1npdistijlogqdistij

其中
p d i s t i j = e q ^ i j t ∑ j = 1 n e q ^ i j t q d i s t i j = e q i j t ∑ j = 1 n e q i j t pdist_{ij}=\frac{e^{\frac{\hat q_{ij}}{t}}}{\sum_{j=1}^{n}e^{\frac{\hat q_{ij}}{t}}}\\ qdist_{ij}=\frac{e^{\frac{q_{ij}}{t}}}{\sum_{j=1}^{n}e^{\frac{q_{ij}}{t}}} pdistij=j=1netq^ijetq^ijqdistij=j=1netqijetqij

L D ( w ) L_D(w) LD(w)即让模型尽可能的记住旧类别的输出分布。t是一个超参数,在本论文中, t = 2 t=2 t=2


个人疑问

distillation loss的作用是让模型记住以往学习到的规律,相当于侧面引入了旧数据集,从而抵抗类别遗忘。

直觉上来说,distillation loss应该只对旧类别数据进行计算,但是新类别数据的旧类别分支输出仍用于计算distillation loss,论文对此给出的解释是“To reinforce the old knowledge”

我认为这种做法的出发点为:旧模型对于新类别数据的输出(经softmax处理),也是一种旧知识,也需要防止遗忘,因此,新模型对于新类别数据的旧类别输出(经softmax处理),与旧模型对于新类别数据的输出(经softmax处理)也要尽可能一致


步骤三:finetuning

使用herding selection算法,从新类别数据中抽取部分数据,构成与旧类别examplar大小相等的数据集,此时各类别数据之间类别平衡,利用该数据集,在小学习率下对模型进行微调,选用的loss函数应该是交叉熵。

步骤二使用类别不平衡的数据训练模型,会导致分类器出现分类偏好,finetuning可以在一定程度上矫正分类器的分类偏好


步骤四:管理 e x a m p l a r examplar examplar

论文给出了两类方法

  1. Fixed number of samples:没有内存上限,每个类别都有 M M M个数据
  2. Fixed memory size:内存上限为 K K K

使用herding selection算法选择新类别数据,构成新类别的 e x a m p l a r examplar examplar


实验

论文训练模型使用了数据增强,具体方式如下:
在这里插入图片描述
每个实验都进行了五次训练,取平均准确率
实验过程没有太多有趣的地方,在此不做过多说明

Fixed memory size

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述


Fixed number of samples

在CIFAR100上的结果如下
在这里插入图片描述
img/cls表示每个examplar中图片的个数


Ablation studies

首先是选择数据构建examplar的方法,论文比对了三类方法

  1. herding selection:平均准确率63.6%
  2. random selection:平均准确率63.1%
  3. histogram selection:平均准确率59.1%

上述三个选择方法的解释如下:
在这里插入图片描述
接下来论文比对了算法各部分对准确率提升的贡献
在这里插入图片描述
上述模型的解释如下
在这里插入图片描述

个人理解

类别不平衡会导致灾难性遗忘,模型在学习旧类别时,所使用的数据是充分的,引入知识蒸馏loss,就是尽可能保留旧数据上学习到的规律,在训练时,相当于侧面引入了旧数据。

论文在distillation loss的基础上又引入了类别平衡条件下的finetuning,相当于进一步抵抗增量学习下类别不平衡的导致的分类器偏好问题,由此取得模型性能的提升。

  • 8
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值