《Learning without Forgetting》 论文阅读笔记


原文链接: Learning without Forgetting

1. 引言

在很多实际的视觉应用中,都需要在保留旧知识的基础上学习新的知识。例如,为了建筑工地安全,已经开发了一个服装安全检测系统,可以检测工作人员是否穿了反光背心或者安全帽,但是项目负责人想系统也可以检测出工作人员是否穿了合适的工作鞋。理想的情况是,新任务可以与旧任务共享参数,而不会发生灾难性遗忘(Catastrophic Forgetting),或者就是我们还可以得到之前的数据,然后对它们一起训练。但是往往一些数据因为版权或者隐私问题无法再获取,或者数据集增加网络模型也会随着扩增才能学习到更多的任务,网络会越来越大。
在这里插入图片描述

如上图,本文将一个卷积神经网络的参数分别以 θs 和 θo 来表示。 上图是 AlexNet的结构图,θs 表示网络的前 5 层卷积层 + 后两层全连接层; 最后一层是与类别相关的输出层,其参数单独用 θo 表示。若加入一个新的分类任务,就将 new task 的参数先随机初始化,表示为 θn

目前基于已有的 θs 来学习 θn 主要有以下三种方法:

  • 特征提取(Feature extraction) : θs 和 θo 都不变,从网络中的一个或多个中间层的输出被用来训练新任务的参数 θn
  • 微调(Finetune): θs 和 θo 都进行优化,其中 θo 是固定的。 使用一个低的学习率来学习 θn 。还有一种可能是将 θs 中的前5层参数固定,来防止过拟合,而微调所有的全连接层参数,本文称这种实验为 Finetune-FC
  • 联合训练(Joint training): 所有的参数 θs,θo,θn 都进行学习优化,通常这个方法产生的结果是最优的,所以一般视为增量学习方法的性能上界(upper bound)

上述方法均有缺点:特征提取通常在新任务上表现不佳,因为共享参数不能表示一些新任务独有的特征表示。微调也因为没有旧任务样本的指导而在旧任务上表现变差,为了防止这个问题,采用复制网络并进行微调后作为网络新的分支的方法会随着任务的增加,测试时间线性增加,只微调全连接层,即 Finetune-FC 也不理想。 联合训练以前的数据可能因为隐私问题得不到,而且网络的容量也会不断增加。

这里本文提出一种名为 Learning without Forgetting (LwF)的方法,仅仅使用新任务的样本来训练网络,就可以得到在新任务和旧任务都不错的效果。本文的方法类似于联合训练,但不同的是LwF 不需要旧任务的数据和标签。主要思路如下图:

2. 相关工作

多任务学习,迁移学习和相关的方法都有很长的发展史了,简要概括,LwF 方法可以视为是 蒸馏网络(知识蒸馏,相关论文笔记可见:另一篇笔记)和微调的结合。 微调是使用一个已经训练完成的网络参数来对新的网络进行初始化,并在一个低的学习率下,更新参数,在新任务数据中重新找到一个局部最优解。而知识蒸馏则可以训练一个小的网络可以在源数据集上或者大量无标签的数据集上达到一个复杂网络的性能。本文的方法尝试使用同一个数据集对新任务进行监督学习,对旧任务进行非监督学习从而得到一个参数集(θs,θo,θn)可以在新/旧任务都表现良好。

2.1 方法对比

特征提取:[5][12]等人使用一个预训练模型去计算输入图片的特征,通常这个特征来自最后一层隐藏层或者是多个隐藏层。分类器在这些特征上进行学习,通常可以得到不错的结果,有时比人类手动选择特征更好[5]。更进一步有[13]研究如何选择超参数来达到更好的效果。特征提取不改变原始网络,可以让新任务在之前的任务提取的复杂特征上进行学习。但是,这些特征不是新任务特有的,一般都可以通过 微调 进一步优化。

微调:[6] 修改了一个预训练模型的参数来学习新的任务。输出层用增加的参数随机初始化,然后使用一个小学习率对原始网络的参数进行调整来降低新任务的损失函数。有时,网络的一部分参数也可能保持固定来防止过拟合。选择合适的超参数来训练,一般结果都会优于特征提取[6],[13]或者是从头开始训练[14][15]。微调改变共享参数 θs 来让它们对新任务更具判别性,使用低学习率是一个间接的形式来保留一些旧任务的表示结构。

多任务学习:[7]目标是整合所有任务中通用的知识从而同时改善所有的任务。对于神经网络来书,通常底层的层都是共享的,而高层的网络层都是 task-specific 的。 多任务学习需要训练时,所有任务的样本都要包含在训练集中。

在每个网络层中增加新的神经元也是一种学习新任务的时候保留旧任务参数的方法。例如[17] 提出 Deep Block-Modular Neural network,[18] 提出 Progressive Nerual network 用作增强学习。,这种方法保留原始网络的所有参数,用新增加的节点来学习新的任务,如果训练样本不充足的情况下,会比 微调和特征提取都要差,因为这些新增加的样本等于是从头开始训练。

2.2 局部相关方法

本文的工作与一些关于迁移知识的方法相关。[11] 提出了知识蒸馏,可以将一个大模型的知识传授给一个更小的网络以便于更有效的应用。这个小网络使用一个修改后的交叉熵损失函数进行训练,使得小网络的输出相似于大网络的输出。[19] 则是在更深的网络的中间层加入一些额外的指导,实现知识迁移。[20] 提出了 Net2Net 方法,可以由一个已有的网络立即生成一个更深,更宽的网络。这些方法都是想要生成一个结果相似于原网络的不同的网络结构,本文则是想要找到新的 θs, θo 可以更好的表示旧任务并学习新参数 θn

特征提取和微调是域适应或者迁移学习的特殊情况。它们不同于多任务学习的地方是,迁移学习中所有的任务不是同时被优化的。迁移学习是使用一个任务的知识来帮助学习另一个知识,例如[21][22], 域适应[23]方面也用知识蒸馏分方法来帮助训练新任务而不用保留旧任务的性能。 不管是迁移学习还是域适应都需要所有任务的训练样本,不管是有标签的还是没有标签的。

随着时间集成学习的方法,例如[24][25]也与本文有所关联。 [24]是关注点在迁移知识的时候可以灵活增加新的任务。[25]则是关注在不断构建多样的知识和经验,而这些方法都没有提到如何在没有原始训练数据的基础上仍保留旧任务的精度。[26]提供了一种方法可以有效地增加新任务,并且在只有新样本的情况下共同训练所有的任务,但是它式假设所有分类器和回归模型的权重都是线性可分解的。

2.3 同时期的方法

有两个新提出的方法在只有新任务样本的情况下不断增加和整合新任务。

A-LTM[8] : 主要是先在小的训练集上训练 old task,然后使用大的数据集来作为 new task 学习,本文则是先在大数据集(ImageNet) 上训练old task,然后从一些小数据上(VOC) 作为 new task 学习,所以 A-LTM 得到的性能比本文的要差,最后 A-LTM 给出结论是,想要保留旧任务的性能,保留旧任务的数据集是必须的。本文则是得到了相反的结果,没有原任务的数据仍然可以保留不错的性能。从现实角度,我们的思路更加贴近实际。

Less Forgetting learning [9]: 则是任务特定任务的分界线不应该改变,并且保持所有旧任务的最后一层不变,而本文的方法则是联合优化共享参数和最后一层。本文后面通过实验数据表明LwF方法更优。

3. 不遗忘学习

训练过程:
使用带有正则化的 SGD 训练网络

1) 首先固定 θs,θo 不变,然后使用新任务数据集训练 θn 直至收敛 (热身阶段,warm-up step)
2) 然后再联合训练所有参数,θs,θo ,θn直至网络收敛

损失函数:

Loss1: 正常分类网络的损失函数 + 正则项

(此处省略了正则项)

Loss2: 蒸馏损失函数

最后整个 Loss 表达式来优化 θs,θo ,θn

文中也给出了整体算法思路:

训练时间:

joint training > LwF > finetune > feature extraction

文中后面的部分做了大量的对比实验及方法对比,如果想要深入研究相关领域的同学可以自行深入研读,后面的实验内容和结果对比也是很值得学习的,限于篇幅和本文指向对相关概念性的介绍做归纳和总结,这里就略去了。 欢迎留言,交流学习!!

learning without forgetting是指在进行连续学习任务时,保持之前所学习知识的不被遗忘。为了实现learning without forgetting,可以使用PyTorch这一深度学习框架。 在PyTorch中,可以使用增量学习(incremental learning)的方法。具体步骤如下: 1. 定义初始模型:首先,定义一个初始模型,用于解决第一个学习任务。可以使用PyTorch中的Module类来创建模型,并选择适当的网络结构。 2. 学习第一个任务:使用第一个任务的数据集对模型进行训练。可以使用PyTorch提供的DataLoader类来加载数据集,使用优化器(如Adam或SGD)和损失函数(如交叉熵损失)对模型进行训练。 3. 保存模型参数:在完成第一个任务的训练后,将模型的参数保存起来。可以使用torch.save()函数将参数保存到磁盘上的文件中。 4. 准备新任务:准备新的数据集和标签,用于学习新的任务。可以使用相同的网络结构或者更改网络结构,根据新的任务要求进行适当的调整。 5. 加载之前的模型参数:在开始新的任务训练之前,使用torch.load()函数加载之前保存的模型参数。 6. 设置学习率:由于新的任务可能与之前的任务有不同的难度或重要性,可以设置不同的学习率来适应新任务的特点。可以使用PyTorch中的scheduler类或手动调整学习率。 7. 学习新任务:使用新的数据集对模型进行更新训练。可以使用先前定义的优化器和损失函数,使用torch.nn.Module的train()方法进行训练。 通过以上步骤,可以在PyTorch中实现learning without forgetting。重要的是保存和加载已训练模型参数,并根据新任务的要求进行适当的调整。同时,可以根据需要设置学习率等超参数,以更好地适应不同任务的特点。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值