论文分享——OOD-MAML: Meta-Learning for Few-Shot Out-of-Distribution Detection and Classification

本文介绍了一种名为OOD-MAML的方法,用于解决小样本学习中的OOD样本检测问题。通过扩展N-way分类任务为N+1-way,利用元学习策略,该方法在完成分类的同时能有效识别分布外样本。实验结果显示,OOD-MAML在小样本和OOD样本区分上表现良好,强调了任务相关的伪样本对提高性能的重要性。
摘要由CSDN通过智能技术生成

针对小样本的OOD检测方法


论文地址:OOD-MAML (于2020年发表在NIPS。)

前言

上一次集中分享了集中MAML的变体算法,包括Reptile,DKT,MTNET,CAVIA,TAML,Pruning等。今天让我们一起来看一篇新的改进思路——分布外样本检测和分类的MAML算法。其核心思路是:把N way K shot分类问题扩展为N+1 way K shot问题,增加的一类是未知的OOD样本。把一个N way的任务划分为N个子任务,每个任务判断是否属于该类。若不属于任何一类,则说明该样本是分布外样本。 该思路巧妙、经典、值得借鉴,一起来看看吧。


一、摘要

本文首先提出小样本学习面临的两类challenge,其一是缺少从已知类中学习训练数据的分布,该问题可以用元学习的方法解决;其二是训练时缺少分布外样本,针对该问题,作者提出了利用OOD-MAML的方法解决。
本文贡献:1.在完成小样本分类问题的同时可以检测分布外样本。2.将N way分类任务转换为N个子任务,可以处理训练和测试N不同的情况。


二、相关工作

作者发现,深度神经网络(DNN)在面对分布外样本时会产生过高的置信度,也就是误判概率高的问题。面对这样的问题,一些做法是提供不确定性估计(Uncertainty qualification, UQ)。一些思路包括使用Softmax scores、改进的Softmax scores以及马氏距离 (MAH)方法等解处理分布外检测问题。


三、方法论

作者首先介绍了元学习的基本设置,这里就略去不讲了,有疑问的可以看我系列文章:基于MAML的改进方法总结

OOD-MAML的核心思想

这里我用一个N=3的样本分类任务举例。假设任务是猫、狗和马的分类任务,用one-hot encoding独热编码来表示,即:猫:(1,0,0),狗:(0,1,0),马:(0,0,1),我们这里多加一类:(0,0,0)。如果机器判断为(0,0,0),则说明是OOD样本。进而我们把N=3的样本分类任务变为三个子任务,即第一个任务分辨图片是不是猫,是为1否为0,接着重复上述操作判断是否为狗和马。若上述三个都判断为0就是OOD样本了。综上,我们用N=4的样本分类任务解决N=3的样本分类及OOD检测问题。

元训练阶段

通过前文叙述,相信大家已经很好地理解了上述过程。下面通过流程图进一步分析,注意,本文的任务设定是每个任务只包含一类样本,只判断是或不是该类(Task setting in OOD-MAML: We construct D t r a i n ∈ D m e t a − t r a i n {D_{train}} \in {D_{meta - train}} DtrainDmetatrain to contain K K K examples of one known class.):
在这里插入图片描述

这里需要注意的是,在元训练阶段OOD样本是用噪声合成的伪样本,而在元测试阶段是真正的分布外样本。 作者阐述该观点的方式被我有失偏颇地总结如下:

  1. 由于在训练阶段并不知道OOD样本的特征,因此在没有先验的情况下用噪音手动生成伪样本比真样本合适,否则容易导致分类器有偏;
  2. 如果使用任务无关的伪样本,这意味着分类器并不能更好地学到对当前任务敏感的边界(sharp decision boundary),所以伪样本是和当前任务相关的,需要因任务而适应(作者实验中做图再次说明)。

下面我们依次看作者的损失函数以及算法流程图具体是怎么设计的:

Loss function

损失函数使用最常见的交叉熵。

L θ ; T i i n = − 1 K ∑ k = 1 K log ⁡ f θ ( x k i ) L_{\theta ;{T_i}}^{in} = - \frac{1}{K}\sum\limits_{k = 1}^K {\log } {f_\theta }\left( {{\bf{x}}_k^i} \right) Lθ;Tiin=K1k=1Klogfθ(xki)
L θ ; T i o u t ( θ f a k e ) = − 1 M ∑ m = 1 M log ⁡ ( 1 − f θ ( θ f a k e , m ) ) L_{\theta ;{T_i}}^{{out}}\left( {{\theta _{{fake}}}} \right) = - \frac{1}{M}\sum\limits_{m = 1}^M {\log } \left( {1 - {f_\theta }\left( {{\theta _{{fake},m}}} \right)} \right) Lθ;Tiout(θfake)=M1m=1Mlog(1fθ(θfake,m))
L θ ; T i ( D t r a i n i , θ f a k e ) = L θ ; T i i n + L θ ; T i o u t ( θ f a k e ) {L_{\theta ;{T_i}}}\left( {D_{{train}}^i,{\theta _{{fake}}}} \right) = L_{\theta ;{T_i}}^{in} + L_{\theta ;{T_i}}^{{out}}\left( {{\theta _{{fake}}}} \right) Lθ;Ti(Dtraini,θfake)=Lθ;Tiin+Lθ;Tiout(θfake)
where θ f a k e = ( θ f a k e , 1 , … θ f a k e , M ) {\theta _{{fake}}} = \left( {{\theta _{{fake},1}}, \ldots {\theta _{{fake},M}}} \right) θfake=(θfake,1,θfake,M).

Algorithm

作者通过梯度更新生成对抗样本,也就是学习网络参数和伪样本交替进行:

θ i = θ − α ∇ θ L θ ; T i ( D t r a i n i , θ f a k e ) (1) {\theta ^i} = \theta - \alpha {\nabla _\theta }{L_{\theta ;{T_i}}}\left( {D_{{train}}^i,{\theta _{{fake}}}} \right)\tag{1} θi=θαθLθ;Ti(Dtraini,θfake)(1)
θ f a k e i = θ f a k e − β f a k e ⊙ s i g n ( − ∇ θ f a k e L θ i ; T i ( D t r a i n i , θ f a k e ) ) (2) \theta _{{fake}}^i = {\theta _{{fake}}} - {\beta _{{fake}}} \odot {\mathop{\rm sign}\nolimits} \left( { - {\nabla _{{\theta _{{fake}}}}}{L_{{\theta ^i};{T_i}}}\left( {D_{{train}}^i,{\theta _{{fake}}}} \right)} \right)\tag{2} θfakei=θfakeβfakesign(θfakeLθi;Ti(Dtraini,θfake))(2)
θ a d a p t i = θ − α ∇ θ L θ i ; T i ( D t r a i n i , ( θ f a k e , θ f a k e i ) ) (3) \theta _{{adapt}}^i = \theta - \alpha {\nabla _\theta }{L_{{\theta ^i};{T_i}}}\left( {D_{{train}}^i,\left( {{\theta _{{fake}}},\theta _{{fake}}^i} \right)} \right)\tag{3} θadapti=θαθLθi;Ti(Dtraini,(θfake,θfakei))(3)
( θ , θ f a k e , β f a k e ) ← ( θ , θ f a k e , β f a k e ) − γ ∇ ( θ , θ f a k e , β f a k e ) ∑ T i ∼ P ( T ) L ( D t e s t i ) (4) \left( {\theta ,{\theta _{{fake}}},{\beta _{{fake}}}} \right) \leftarrow \left( {\theta ,{\theta _{{fake}}},{\beta _{{fake}}}} \right) - \gamma {\nabla _{\left( {\theta ,{\theta _{{fake}}},{\beta _{{fake}}}} \right)}}\sum\limits_{{{T_i} \sim P(T)}} L \left( {D_{{test}}^i} \right)\tag{4} (θ,θfake,βfake)(θ,θfake,βfake)γ(θ,θfake,βfake)TiP(T)L(Dtesti)(4)

其中, L ( D t e s t i ) = − 1 Q ∑ q = 1 Q y q i log ⁡ p q i + ( 1 − y q i ) log ⁡ ( 1 − p q i ) L\left( {D_{{test}}^i} \right) = - \frac{1}{Q}\sum\limits_{q = 1}^Q {y_q^i} \log p_q^i + \left( {1 - y_q^i} \right)\log \left( {1 - p_q^i} \right) L(Dtesti)=Q1q=1Qyqilogpqi+(1yqi)log(1pqi),
p q i = f θ a d a p t i ( x q i ) p_q^i = {f_{\theta _{{adapt}}^i}}\left( {{\bf{x}}_q^i} \right) pqi=fθadapti(xqi), γ > 0 \gamma > 0 γ>0 是元学习率。 ( θ f a k e , θ i f a k e ) ({\theta _{{fake}}},{\theta ^i}_{{fake}}) (θfake,θifake) θ f a k e {\theta _{{fake}}} θfake θ i f a k e {\theta ^i}_{{fake}} θifake的拼接操作(concatenation)。

元测试阶段

元测试阶段也就是微调过程和内层循环大同小异,看下作者是怎么叙述的:

p j ( x ) = [ f θ a d a p t j 1 ( x ) , … , f θ a d a p t j N ( x ) ] {p^j}(x) = \left[ {{f_{\theta _{{adapt}}^{j1}}}(x), \ldots ,{f_{\theta _{{adapt}}^{jN}}}(x)} \right] pj(x)=[fθadaptj1(x),,fθadaptjN(x)]
Note that f θ a d a p t j n ( ⋅ ) {f_{\theta _{{adapt}}^{jn}}}( \cdot ) fθadaptjn() are binary classifiers, and the label 0 can be assigned if f θ a d a p t j n ( ⋅ ) < λ {f_{\theta _{{adapt}}^{jn}}}( \cdot ) < \lambda fθadaptjn()<λ, where λ \lambda λ is a threshold, while the label 1 is assigned otherwise, in the test phase. The threshold λ \lambda λ can be determined based on some criteria such as the true positive ratio (TPR), or simply set to 0.5 as a default value for binary classification.

很简单啦,大家自己读读看就好。


四、实验

Baselines. (i) ODIN with pretrained MAML (ii) ODIN with pretrained PN (iii) MAH with pretrained MAML. (iv) (N+1) classes with MAML without fake images ((N+1)-MAML) (v) (N+1) classes with MAML with ( θ f a k e , θ i f a k e ) ({\theta _{{fake}}},{\theta ^i}_{{fake}}) (θfake,θifake) ((N+1)-MAML*).
Task setting. Set the 5-shot data of one class in D t r a i n {D_{train}} Dtrain and set 50 samples in D t e s t {D_{test}} Dtest, where 25 samples are drawn from seen classes.
Datasets. Omniglot, CIFAR-FS and MiniImageNet.
Evaluation criteria. (i) true positive rate (TPR). (ii) true negative rate (TNR). T N R = T N / ( T N + F P ) TNR = TN/(TN + FP) TNR=TN/(TN+FP).
部分实验结果如下,感兴趣的朋友可以在原论文找到更多细节。
实验1
最后来分析下作者认为伪样本为什么要是任务相关的:这里绿色是任务无关的伪样本,蓝色空心圈是任务相关的伪样本,红色是OOD样本。可以看到图(b)学到的边界更加sharp且错判概率更低。
在这里插入图片描述


五、总结

大家对于本文和小样本OOD问题有什么见解呢?有什么可以改进的想法欢迎在评论区留言讨论!

  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
### 回答1: Out-of-distribution是指在模型训练时未曾出现过的数据分布,也称为“未知数据”。在模型面对未知数据时,其预测结果可能会出现误差或不确定性。因此,对于模型的鲁棒性和泛化能力的提升,需要对out-of-distribution数据进行有效的识别和处理。 ### 回答2: out-of-distributionOoD)是指模型在测试阶段遇到了其训练数据集之外的样本或类别。当模型只使用特定的数据集进行训练时,它可能无法处理那些与训练数据不同的输入。这些新的样本可能是在颜色、形状、大小等方面与训练数据有所不同,也可能属于未在训练数据中出现过的类别。 遇到OoD样本的问题是模型的泛化能力不足。模型在训练数据中表示和学习的特征可能过于特定,无法推广到训练数据集之外的样本。这可能导致模型的预测不准确或不可靠。 为了解决OoD问题,有几种方法可以采取。一种常见的方法是收集更多来自OoD分布的样本,并将其添加到训练数据中,以使模型能够更好地学习如何处理这些新样本。另一种方法是使用一些先验知识或规则,对OoD样本进行检测和筛选,以避免对其进行错误预测。 同时,一些研究者提出了一些用于检测OoD样本的新颖性评估方法。这些方法通过利用模型在训练样本和OoD样本上的输出差异来判断一个样本是否属于OoD类别。这种方法可以帮助我们识别OoD样本,并采取相应的措施,以提高模型的泛化性能。 综上所述,解决out-of-distribution问题是训练一个具有较强泛化能力的模型的重要步骤。只有当模型能够有效处理新的样本和未见过的类别时,才能提高模型的可靠性和适用性。 ### 回答3: "out-of-distribution"是指数据集中没有包含的数据样本或样本类别。在机器学习和深度学习中,数据集通常用于训练和测试模型的性能。然而,在现实世界中,我们会遇到无法准确分类的新数据,这些数据就属于"out-of-distribution"。这可能是因为这些数据具有与训练数据不同的特征,或者因为数据集的覆盖范围有限。 "out-of-distribution"的出现可能会对模型的性能和鲁棒性产生负面影响。由于模型没有前面没有见过这些类型的数据,它可能会对其进行错误的分类或给出不确定的预测结果。这种情况在实际应用中特别重要,因为我们希望模型能够在各种不同的情况下表现得可靠和准确。 为了解决"out-of-distribution"问题,一种常见的方法是通过收集更多具有代表性的训练数据来增加数据集的覆盖范围。这样模型可以更好地学习不同类型的数据特征,并提高对"out-of-distribution"数据的泛化能力。另外,使用先进的模型架构和优化算法也可以增强模型的鲁棒性。 除了增加训练数据和改进模型架构外,还可以使用一些检测方法来识别"out-of-distribution"的样本。这些方法可以根据模型的置信度、预测熵或数据分布等特征来判断样本是否属于训练集之外的数据。这些方法可以帮助我们发现并处理那些可能造成模型失效的"out-of-distribution"数据。 总之,"out-of-distribution"是指在训练数据之外的数据样本或样本类别。对于机器学习和深度学习任务,了解和解决"out-of-distribution"问题是提高模型性能和鲁棒性的关键。通过增加训练数据、改进模型架构和使用检测方法,我们可以减少"out-of-distribution"带来的负面影响。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

keive13

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值