论文阅读Mark(五)Prototype-Sample Relation Distillation Towards Replay-Free Continual Learning. ICML2023

原论文

INNOVATION POINTS

本篇论文的重大创新点就是利用了prototype来控制样本之间的距离差距,尽量保持相对位置不变,以缓解Network中的遗忘。相对于目前在持续学习领域中处于主导地位的Replay方法,该方法不存储之前学习过的数据,也就不进行数据重放来巩固之前的知识,而是利用关系蒸馏和监督对比学习,解决了Replay方法的存储数据和数据隐私等问题,为replay-free的持续学习领域给予启发。

PROTOTYPE 原型

作者在文中所提到的prototype 是指一个线性层(作者代码中体现)。对于N个类,作者设置了N个prototypes,每个prototype由一个线性层构成,prototype接受一个sample作为输入,得到该sample的线性输出向量,称之为score,在之后的Objective function中作条件和constraints。
Prototype

OBJECTIVE LOSS FUNCTION

模型的损失函数由三部分构成,来优化模型网络结构和prototypes,缓解遗忘。

Supervised Contrastive Learning

作者利用SC的思想,弃用cross entropy ,使用supervised contrastive loss 作为目标函数来进行优化。
L S C ( X ) = − ∑ x i ∈ X 1 ∣ A ( i ) ∣ L S C ( x i ) L_{SC}(X) = - \sum_{x_i\in X } \frac{1}{|A(i)|}L_{SC}(x_i) LSC(X)=xiXA(i)1LSC(xi)
对于每个类的样本的loss function:
L S C ( x i ) = ∑ X p ∈ A ( i ) log ⁡ h ( g ∘ f ( x p ) , g ∘ f ( x i ) ) ∑ x a ∈ X / x i h ( g ∘ f ( x a ) , g ∘ f ( x i ) ) L_{SC}(x_i) = \sum_{X_p \in A(i)}\log \frac{h\left(g \circ f\left(\mathbf{x}_{p}\right), g \circ f\left(\mathbf{x}_{i}\right)\right)}{\sum_{\mathbf{x}_{a} \in \mathbf{X} / x_{i}} h\left(g \circ f\left(\mathbf{x}_{a}\right), g \circ f\left(\mathbf{x}_{i}\right)\right)} LSC(xi)=XpA(i)logxaX/xih(gf(xa),gf(xi))h(gf(xp),gf(xi))
h ( a , b ) = exp ⁡ ( s i m ( a , b ) / t ) h(a,b) = \exp(sim(a,b)/t) h(a,b)=exp(sim(a,b)/t) ,sim(a,b)即计算向量a,b的cos(a,b)。

Prototype Learning without Contrasts

max ⁡   L p = ( X ) = − 1 ∣ X ∣ ∑ x i , y i ∈ X , Y s i m ( p y i , s g [ f θ ( x i ) ] ) \max \space L_p = (X) = - \frac{1}{|X|} \sum_{x_i,y_i \in X,Y} sim(p_{y_i},sg[f_\theta(x_i)]) max Lp=(X)=X1xi,yiX,Ysim(pyi,sg[fθ(xi)])
sg是指梯度截断操作。该函数用来计算prototype与sample的representation之间的相似度,通过最大化该函数,对prototyp进行优化。
一旦获得了prototypes,我们即可以利用sample的representation和prototypes来计算相似度,以判断哪个类与sample最相似。

Prototypes-Samples Similarity Distillation

因为在对网络进行优化的过程中(SC loss),对参数进行更新,所以对每个样本进行forward得到的feature就会改变,进而使得prototype变得“过时”,得到的预测结果会有很大偏差(forgetting)。
计算Prototype与当前sample的Softmax输出:
P t ( p k t , X ) i = h ( p k t , f θ t ( x i ) ) ∑ x j ∈ X h ( p k t , f θ t ( x j ) ) \mathcal{P}_{t}\left(\mathbf{p}_{k}^{t}, \mathbf{X}\right)_{i}=\frac{h\left(\mathbf{p}_{k}^{t}, f_{\theta_{t}}\left(\mathbf{x}_{i}\right)\right)}{\sum_{\mathbf{x}_{\mathbf{j}} \in \mathbf{X}} h\left(\mathbf{p}_{k}^{t}, f_{\theta_{t}}\left(\mathbf{x}_{j}\right)\right)} Pt(pkt,X)i=xjXh(pkt,fθt(xj))h(pkt,fθt(xi))
作者利用KL Divergence来进行一个relation distillation,保持当前prototype与之前的prototype的相似度:
L d ( P ) = ∑ p k ∈ P o K L ( P t ( k ) ∥ P t − 1 ( k ) ) \mathcal{L}_{d}(\mathbf{P})=\sum_{\mathbf{p}_{k} \in \mathbf{P}_{o}} K L\left(\mathcal{P}_{t}(k) \| \mathcal{P}_{t-1}(k)\right) Ld(P)=pkPoKL(Pt(k)Pt1(k))

最后总的loss function是:
L ( X ) = L s c ( X ) + α L p ( X , P c ) + β L d ( X , P o ) \mathcal{L}(\mathbf{X}) = \mathcal{L_{sc}}(\mathbf{X})+\alpha\mathcal{L_p}(\mathcal{\mathbf{X}},P_c)+ \beta\mathcal{L_d}(\mathcal{\mathbf{X}},P_o) L(X)=Lsc(X)+αLp(X,Pc)+βLd(X,Po)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值