针对小样本的OOD检测方法
论文地址:OOD-MAML (于2020年发表在NIPS。)
前言
上一次集中分享了集中MAML的变体算法,包括Reptile,DKT,MTNET,CAVIA,TAML,Pruning等。今天让我们一起来看一篇新的改进思路——分布外样本检测和分类的MAML算法。其核心思路是:把N way K shot分类问题扩展为N+1 way K shot问题,增加的一类是未知的OOD样本。把一个N way的任务划分为N个子任务,每个任务判断是否属于该类。若不属于任何一类,则说明该样本是分布外样本。 该思路巧妙、经典、值得借鉴,一起来看看吧。
一、摘要
本文首先提出小样本学习面临的两类challenge,其一是缺少从已知类中学习训练数据的分布,该问题可以用元学习的方法解决;其二是训练时缺少分布外样本,针对该问题,作者提出了利用OOD-MAML的方法解决。
本文贡献:1.在完成小样本分类问题的同时可以检测分布外样本。2.将N way分类任务转换为N个子任务,可以处理训练和测试N不同的情况。
二、相关工作
作者发现,深度神经网络(DNN)在面对分布外样本时会产生过高的置信度,也就是误判概率高的问题。面对这样的问题,一些做法是提供不确定性估计(Uncertainty qualification, UQ)。一些思路包括使用Softmax scores、改进的Softmax scores以及马氏距离 (MAH)方法等解处理分布外检测问题。
三、方法论
作者首先介绍了元学习的基本设置,这里就略去不讲了,有疑问的可以看我系列文章:基于MAML的改进方法总结
OOD-MAML的核心思想
这里我用一个N=3的样本分类任务举例。假设任务是猫、狗和马的分类任务,用one-hot encoding独热编码来表示,即:猫:(1,0,0),狗:(0,1,0),马:(0,0,1),我们这里多加一类:(0,0,0)。如果机器判断为(0,0,0),则说明是OOD样本。进而我们把N=3的样本分类任务变为三个子任务,即第一个任务分辨图片是不是猫,是为1否为0,接着重复上述操作判断是否为狗和马。若上述三个都判断为0就是OOD样本了。综上,我们用N=4的样本分类任务解决N=3的样本分类及OOD检测问题。
元训练阶段
通过前文叙述,相信大家已经很好地理解了上述过程。下面通过流程图进一步分析,注意,本文的任务设定是每个任务只包含一类样本,只判断是或不是该类(Task setting in OOD-MAML: We construct
D
t
r
a
i
n
∈
D
m
e
t
a
−
t
r
a
i
n
{D_{train}} \in {D_{meta - train}}
Dtrain∈Dmeta−train to contain
K
K
K examples of one known class.):
这里需要注意的是,在元训练阶段OOD样本是用噪声合成的伪样本,而在元测试阶段是真正的分布外样本。 作者阐述该观点的方式被我有失偏颇地总结如下:
- 由于在训练阶段并不知道OOD样本的特征,因此在没有先验的情况下用噪音手动生成伪样本比真样本合适,否则容易导致分类器有偏;
- 如果使用任务无关的伪样本,这意味着分类器并不能更好地学到对当前任务敏感的边界(sharp decision boundary),所以伪样本是和当前任务相关的,需要因任务而适应(作者实验中做图再次说明)。
下面我们依次看作者的损失函数以及算法流程图具体是怎么设计的:
Loss function
损失函数使用最常见的交叉熵。
L θ ; T i i n = − 1 K ∑ k = 1 K log f θ ( x k i ) L_{\theta ;{T_i}}^{in} = - \frac{1}{K}\sum\limits_{k = 1}^K {\log } {f_\theta }\left( {{\bf{x}}_k^i} \right) Lθ;Tiin=−K1k=1∑Klogfθ(xki)
L θ ; T i o u t ( θ f a k e ) = − 1 M ∑ m = 1 M log ( 1 − f θ ( θ f a k e , m ) ) L_{\theta ;{T_i}}^{{out}}\left( {{\theta _{{fake}}}} \right) = - \frac{1}{M}\sum\limits_{m = 1}^M {\log } \left( {1 - {f_\theta }\left( {{\theta _{{fake},m}}} \right)} \right) Lθ;Tiout(θfake)=−M1m=1∑Mlog(1−fθ(θfake,m))
L θ ; T i ( D t r a i n i , θ f a k e ) = L θ ; T i i n + L θ ; T i o u t ( θ f a k e ) {L_{\theta ;{T_i}}}\left( {D_{{train}}^i,{\theta _{{fake}}}} \right) = L_{\theta ;{T_i}}^{in} + L_{\theta ;{T_i}}^{{out}}\left( {{\theta _{{fake}}}} \right) Lθ;Ti(Dtraini,θfake)=Lθ;Tiin+Lθ;Tiout(θfake)
where θ f a k e = ( θ f a k e , 1 , … θ f a k e , M ) {\theta _{{fake}}} = \left( {{\theta _{{fake},1}}, \ldots {\theta _{{fake},M}}} \right) θfake=(θfake,1,…θfake,M).
Algorithm
作者通过梯度更新生成对抗样本,也就是学习网络参数和伪样本交替进行:
θ i = θ − α ∇ θ L θ ; T i ( D t r a i n i , θ f a k e ) (1) {\theta ^i} = \theta - \alpha {\nabla _\theta }{L_{\theta ;{T_i}}}\left( {D_{{train}}^i,{\theta _{{fake}}}} \right)\tag{1} θi=θ−α∇θLθ;Ti(Dtraini,θfake)(1)
θ f a k e i = θ f a k e − β f a k e ⊙ s i g n ( − ∇ θ f a k e L θ i ; T i ( D t r a i n i , θ f a k e ) ) (2) \theta _{{fake}}^i = {\theta _{{fake}}} - {\beta _{{fake}}} \odot {\mathop{\rm sign}\nolimits} \left( { - {\nabla _{{\theta _{{fake}}}}}{L_{{\theta ^i};{T_i}}}\left( {D_{{train}}^i,{\theta _{{fake}}}} \right)} \right)\tag{2} θfakei=θfake−βfake⊙sign(−∇θfakeLθi;Ti(Dtraini,θfake))(2)
θ a d a p t i = θ − α ∇ θ L θ i ; T i ( D t r a i n i , ( θ f a k e , θ f a k e i ) ) (3) \theta _{{adapt}}^i = \theta - \alpha {\nabla _\theta }{L_{{\theta ^i};{T_i}}}\left( {D_{{train}}^i,\left( {{\theta _{{fake}}},\theta _{{fake}}^i} \right)} \right)\tag{3} θadapti=θ−α∇θLθi;Ti(Dtraini,(θfake,θfakei))(3)
( θ , θ f a k e , β f a k e ) ← ( θ , θ f a k e , β f a k e ) − γ ∇ ( θ , θ f a k e , β f a k e ) ∑ T i ∼ P ( T ) L ( D t e s t i ) (4) \left( {\theta ,{\theta _{{fake}}},{\beta _{{fake}}}} \right) \leftarrow \left( {\theta ,{\theta _{{fake}}},{\beta _{{fake}}}} \right) - \gamma {\nabla _{\left( {\theta ,{\theta _{{fake}}},{\beta _{{fake}}}} \right)}}\sum\limits_{{{T_i} \sim P(T)}} L \left( {D_{{test}}^i} \right)\tag{4} (θ,θfake,βfake)←(θ,θfake,βfake)−γ∇(θ,θfake,βfake)Ti∼P(T)∑L(Dtesti)(4)
其中,
L
(
D
t
e
s
t
i
)
=
−
1
Q
∑
q
=
1
Q
y
q
i
log
p
q
i
+
(
1
−
y
q
i
)
log
(
1
−
p
q
i
)
L\left( {D_{{test}}^i} \right) = - \frac{1}{Q}\sum\limits_{q = 1}^Q {y_q^i} \log p_q^i + \left( {1 - y_q^i} \right)\log \left( {1 - p_q^i} \right)
L(Dtesti)=−Q1q=1∑Qyqilogpqi+(1−yqi)log(1−pqi),
p
q
i
=
f
θ
a
d
a
p
t
i
(
x
q
i
)
p_q^i = {f_{\theta _{{adapt}}^i}}\left( {{\bf{x}}_q^i} \right)
pqi=fθadapti(xqi),
γ
>
0
\gamma > 0
γ>0 是元学习率。
(
θ
f
a
k
e
,
θ
i
f
a
k
e
)
({\theta _{{fake}}},{\theta ^i}_{{fake}})
(θfake,θifake) 是
θ
f
a
k
e
{\theta _{{fake}}}
θfake 和
θ
i
f
a
k
e
{\theta ^i}_{{fake}}
θifake的拼接操作(concatenation)。
元测试阶段
元测试阶段也就是微调过程和内层循环大同小异,看下作者是怎么叙述的:
p j ( x ) = [ f θ a d a p t j 1 ( x ) , … , f θ a d a p t j N ( x ) ] {p^j}(x) = \left[ {{f_{\theta _{{adapt}}^{j1}}}(x), \ldots ,{f_{\theta _{{adapt}}^{jN}}}(x)} \right] pj(x)=[fθadaptj1(x),…,fθadaptjN(x)]
Note that f θ a d a p t j n ( ⋅ ) {f_{\theta _{{adapt}}^{jn}}}( \cdot ) fθadaptjn(⋅) are binary classifiers, and the label 0 can be assigned if f θ a d a p t j n ( ⋅ ) < λ {f_{\theta _{{adapt}}^{jn}}}( \cdot ) < \lambda fθadaptjn(⋅)<λ, where λ \lambda λ is a threshold, while the label 1 is assigned otherwise, in the test phase. The threshold λ \lambda λ can be determined based on some criteria such as the true positive ratio (TPR), or simply set to 0.5 as a default value for binary classification.
很简单啦,大家自己读读看就好。
四、实验
Baselines. (i) ODIN with pretrained MAML (ii) ODIN with pretrained PN (iii) MAH with pretrained MAML. (iv) (N+1) classes with MAML without fake images ((N+1)-MAML) (v) (N+1) classes with MAML with
(
θ
f
a
k
e
,
θ
i
f
a
k
e
)
({\theta _{{fake}}},{\theta ^i}_{{fake}})
(θfake,θifake) ((N+1)-MAML*).
Task setting. Set the 5-shot data of one class in
D
t
r
a
i
n
{D_{train}}
Dtrain and set 50 samples in
D
t
e
s
t
{D_{test}}
Dtest, where 25 samples are drawn from seen classes.
Datasets. Omniglot, CIFAR-FS and MiniImageNet.
Evaluation criteria. (i) true positive rate (TPR). (ii) true negative rate (TNR).
T
N
R
=
T
N
/
(
T
N
+
F
P
)
TNR = TN/(TN + FP)
TNR=TN/(TN+FP).
部分实验结果如下,感兴趣的朋友可以在原论文找到更多细节。
最后来分析下作者认为伪样本为什么要是任务相关的:这里绿色是任务无关的伪样本,蓝色空心圈是任务相关的伪样本,红色是OOD样本。可以看到图(b)学到的边界更加sharp且错判概率更低。
五、总结
大家对于本文和小样本OOD问题有什么见解呢?有什么可以改进的想法欢迎在评论区留言讨论!