few shot learning-小样本学习入门

参考链接:
https://zhuanlan.zhihu.com/p/85277741
https://blog.csdn.net/weixin_37589575/article/details/92801610
https://blog.csdn.net/weixin_40123108/article/details/89003325

基本概念

小样本学习(few shot learning,FSL)可以看做每个类别样本数目远远小于类别数目,也就是说每个类别仅仅只有几个样本可供训练。

支持集(support set):包含着少量标注的样本。

查询集(query set):包含着未标注的样本,和支持集的类别空间一致。

N-way K-shot: 表示支持集包含着 K K K个类别,每个类别有 N N N个标注样本

episode training: 训练在训练集上进行,采用episode training机制。此机制和test阶段是一样的,即每一个episode 随机采样 K K K个类别,每个类别采样 N N N个标注样本构成支持集,每个类别的剩余一小部分样本构成查询集,怎样将支持集和查询集样本嵌入到合适的空间,使之类内的样本相似度高,类间的样本相似度低,是一个构建的问题。

数据集划分:按照约定,数据集会被划分为训练集,验证集和测试集,三个集之间的类别是不相交的。

test stage:在测试阶段,我们当前拿到的测试集的label当然是已知的。此时也进行类似episode training的机制,每一个episode从测试集中随机采样 K K K个类别,每个类别采样 N N N个标注样本构成支持集,此时支持集的label是已知的也是可用的,之后每个类别的剩余一小部分样本构成查询集,这些样本是需要模型分类的,label是未知的,通过模型计算得到查询集样本的accuracy,此时就模拟了我们在每个类别小样本label已知的情况下,去预测其他相同label空间样本。
在测试阶段我们要进行多次episode,得到多个accuracy的值,所以此时约定取最终的accuracy为多个episode的平均值,并report 95%的置信区间。

流行的方法

1.数据增强和正则化

这类方法比较简单直接,数据增强是针对样本数量过少来增加样本,小样本学习设置下模型非常容易陷入过拟合,因此数据增强和正则化都能作为正则化来防止过拟合。

2.元学习(Meta Learning)

是目前主流的解决方案,首先介绍什么是元学习。
概念介绍: 元学习的目标是利用已经学到的知识来解决新的问题。这也是基于人类学习的机制,我们学习都是基于已有知识的,而不像深度学习一样都是从 0 开始学习的,也称为“学会学习” (Learning to learn)。
元学习将学习(训练)的任务称为meta-training task,新的任务称为meta-test task。在小样本学习中, meta training 阶段将数据集分解为不同的 meta task,去学习类别变化的情况下模型的泛化能力,在 meta testing 阶段,面对全新的类别,不需要变动已有的模型,就可以完成分类。
例如分类 MiniImagenet,其中有 100 个类,我们用其中 60 个类来学习先验知识,20个做 validation,剩余 20 个做测试。注意我们测试的 20 个类和前面 80 个类是完全不同的,也就是新的类、新的概念、新的问题,并且这 20 个类每个类只有很少的几张已知label的图片 (few-shot 问题)!然后前面的 80 个类用来训练模型和确定超参数,也就是学习帮助我们解决新问题的先验知识。

2.1 学习微调 (Learning to Fine-Tune)

这种方法已被广泛地应用。首先给定一个预训练的基础网络,通过含有丰富标签的大规模数据集训练得到的,比如imagenet,然后根据特定领域的数据进行微调,通常在少量的样本上训练一下就可以得到不错的效果。
经典的工作就是MAML(Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks):MAML 的思想是学习一个 初始化参数 (initialization parameter),这个初始化参数在遇到新的问题时,只需要使用少量的样本 (few-shot learning) 进行几步梯度下降就可以取得很好地效果

2.2 度量学习(metric learning)

核心思想:学习一个 embedding 函数,将输入空间(例如图片)映射到一个新的嵌入空间,在嵌入空间中有一个相似性度量来区分不同类。我们的先验知识就是这个 embedding 函数,在遇到新的 task 的时候,只将需要分类的样本点用这个 embedding 函数映射到嵌入空间里面,使用相似性度量比较进行分类。
这里主要讲解两个代表性的方法:Matching Network和

2.2.1 Matching Network

在这里插入图片描述
inference过程

网络结构上图所示, g θ g_{\theta} gθ f θ f_{\theta} fθ分别是支持集和查询集样本的编码器,将图片embed为向量,其中 x ^ \hat{x} x^为查询样本, x i x_{i} xi是支持样本, a ( x ^ , x i ) a(\hat{x},x_{i}) a(x^,xi)从余弦相似度的角度衡量了查询样本与每个支持样本的相关度,进行归一化是因为下面要进行加权的预测样本的label
在这里插入图片描述
这里 y i y_{i} yi是支持样本 x i x_{i} xi的label,利用相似度为权重,进行label的加权求解:
在这里插入图片描述
训练过程
采用episodic training的方式,训练的时候首先采样支持集 S S S,再在支持集样本空间下采样一个batch的数据作为查询集,batch中的每一个样本依次与 S S S计算分类误差,也就是朴素的交叉熵。

2.2.2 Prototypical Networks (原型网络)

在这里插入图片描述
过程比较简单,采用episodic training,每一个episode,将每一个类的支持集的样本representation求解平均作为该类别的原型表示(prototypical representation),对于查询集的样本,将平方欧氏距离作为与类别prototypical representation的相似度度量,求解误差时,拉近查询样本表示与对应的类的原型表示,远离其他类的原型表示。

数据集

1.Omniglot

介绍:它一共包含1623 类手写体,每一类中包含20 个样本。其中这 1623 个手写体类来自 50 个不同地区(或文明)的 alphabets,如:Latin 文明包含 26 个alphabets,Greek 包含 24 个alphabets。
划分:train的是 964 类(30个地区的字母),用于test的是 659 类 (20个地区的字母)。

©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页