INNOVATION POINTS
本篇论文的重大创新点就是利用了prototype来控制样本之间的距离差距,尽量保持相对位置不变,以缓解Network中的遗忘。相对于目前在持续学习领域中处于主导地位的Replay方法,该方法不存储之前学习过的数据,也就不进行数据重放来巩固之前的知识,而是利用关系蒸馏和监督对比学习,解决了Replay方法的存储数据和数据隐私等问题,为replay-free的持续学习领域给予启发。
PROTOTYPE 原型
作者在文中所提到的prototype 是指一个线性层(作者代码中体现)。对于N个类,作者设置了N个prototypes,每个prototype由一个线性层构成,prototype接受一个sample作为输入,得到该sample的线性输出向量,称之为score,在之后的Objective function中作条件和constraints。
OBJECTIVE LOSS FUNCTION
模型的损失函数由三部分构成,来优化模型网络结构和prototypes,缓解遗忘。
Supervised Contrastive Learning
作者利用SC的思想,弃用cross entropy ,使用supervised contrastive loss 作为目标函数来进行优化。
L
S
C
(
X
)
=
−
∑
x
i
∈
X
1
∣
A
(
i
)
∣
L
S
C
(
x
i
)
L_{SC}(X) = - \sum_{x_i\in X } \frac{1}{|A(i)|}L_{SC}(x_i)
LSC(X)=−xi∈X∑∣A(i)∣1LSC(xi)
对于每个类的样本的loss function:
L
S
C
(
x
i
)
=
∑
X
p
∈
A
(
i
)
log
h
(
g
∘
f
(
x
p
)
,
g
∘
f
(
x
i
)
)
∑
x
a
∈
X
/
x
i
h
(
g
∘
f
(
x
a
)
,
g
∘
f
(
x
i
)
)
L_{SC}(x_i) = \sum_{X_p \in A(i)}\log \frac{h\left(g \circ f\left(\mathbf{x}_{p}\right), g \circ f\left(\mathbf{x}_{i}\right)\right)}{\sum_{\mathbf{x}_{a} \in \mathbf{X} / x_{i}} h\left(g \circ f\left(\mathbf{x}_{a}\right), g \circ f\left(\mathbf{x}_{i}\right)\right)}
LSC(xi)=Xp∈A(i)∑log∑xa∈X/xih(g∘f(xa),g∘f(xi))h(g∘f(xp),g∘f(xi))
h
(
a
,
b
)
=
exp
(
s
i
m
(
a
,
b
)
/
t
)
h(a,b) = \exp(sim(a,b)/t)
h(a,b)=exp(sim(a,b)/t) ,sim(a,b)即计算向量a,b的cos(a,b)。
Prototype Learning without Contrasts
max
L
p
=
(
X
)
=
−
1
∣
X
∣
∑
x
i
,
y
i
∈
X
,
Y
s
i
m
(
p
y
i
,
s
g
[
f
θ
(
x
i
)
]
)
\max \space L_p = (X) = - \frac{1}{|X|} \sum_{x_i,y_i \in X,Y} sim(p_{y_i},sg[f_\theta(x_i)])
max Lp=(X)=−∣X∣1xi,yi∈X,Y∑sim(pyi,sg[fθ(xi)])
sg是指梯度截断操作。该函数用来计算prototype与sample的representation之间的相似度,通过最大化该函数,对prototyp进行优化。
一旦获得了prototypes,我们即可以利用sample的representation和prototypes来计算相似度,以判断哪个类与sample最相似。
Prototypes-Samples Similarity Distillation
因为在对网络进行优化的过程中(SC loss),对参数进行更新,所以对每个样本进行forward得到的feature就会改变,进而使得prototype变得“过时”,得到的预测结果会有很大偏差(forgetting)。
计算Prototype与当前sample的Softmax输出:
P
t
(
p
k
t
,
X
)
i
=
h
(
p
k
t
,
f
θ
t
(
x
i
)
)
∑
x
j
∈
X
h
(
p
k
t
,
f
θ
t
(
x
j
)
)
\mathcal{P}_{t}\left(\mathbf{p}_{k}^{t}, \mathbf{X}\right)_{i}=\frac{h\left(\mathbf{p}_{k}^{t}, f_{\theta_{t}}\left(\mathbf{x}_{i}\right)\right)}{\sum_{\mathbf{x}_{\mathbf{j}} \in \mathbf{X}} h\left(\mathbf{p}_{k}^{t}, f_{\theta_{t}}\left(\mathbf{x}_{j}\right)\right)}
Pt(pkt,X)i=∑xj∈Xh(pkt,fθt(xj))h(pkt,fθt(xi))
作者利用KL Divergence来进行一个relation distillation,保持当前prototype与之前的prototype的相似度:
L
d
(
P
)
=
∑
p
k
∈
P
o
K
L
(
P
t
(
k
)
∥
P
t
−
1
(
k
)
)
\mathcal{L}_{d}(\mathbf{P})=\sum_{\mathbf{p}_{k} \in \mathbf{P}_{o}} K L\left(\mathcal{P}_{t}(k) \| \mathcal{P}_{t-1}(k)\right)
Ld(P)=pk∈Po∑KL(Pt(k)∥Pt−1(k))
最后总的loss function是:
L
(
X
)
=
L
s
c
(
X
)
+
α
L
p
(
X
,
P
c
)
+
β
L
d
(
X
,
P
o
)
\mathcal{L}(\mathbf{X}) = \mathcal{L_{sc}}(\mathbf{X})+\alpha\mathcal{L_p}(\mathcal{\mathbf{X}},P_c)+ \beta\mathcal{L_d}(\mathcal{\mathbf{X}},P_o)
L(X)=Lsc(X)+αLp(X,Pc)+βLd(X,Po)