Prototypical Networks for Few-shot Learning(原型神经网络)

1 总体概述

文中提出了一种简单的方法,称为原型网络(Prototypical Networks),用于小样本学习。该方法基于这样一个理念:我们可以通过神经网络学习的表示空间中的示例均值来表示每个类别。我们通过使用 episodic training 特别训练这些网络,使其在小样本设置下表现良好。这种方法比最近的元学习方法简单且高效,并且即使没有为匹配网络(Matching Networks)开发的复杂扩展,仍然能产生最先进的结果(尽管这些扩展也可以应用于原型网络)。我们展示了通过仔细选择距离度量和修改 episodic 学习过程,可以大幅提高性能。我们还进一步展示了如何将原型网络推广到零样本设置,并在 CUB-200 数据集上取得了最先进的结果。未来的研究方向自然是利用其他类型的 Bregman 散度,而不仅仅是平方欧几里得距离,以对应于超出球形高斯分布的类别条件分布。我们对此进行了初步探索,包括为每个类别学习每个维度的方差。但这并未带来任何经验上的提升,这表明嵌入网络本身具有足够的灵活性,而无需每个类别额外拟合参数。总的来说,原型网络的简单性和有效性使其成为小样本学习的一个有前景的方法。

方法

2.1 算法流程图

在这里插入图片描述这里其实求出一个损失,就让样品离自己的类别原型距离近,离其他的样本原型距离远。

中间部分过于复杂这里省略

2.6 设计的细节

距离公式的选择:欧几里得公式:关于距离公式的选择可以使用余弦相似度也可以使用欧几里得公式进行。然后作者发现欧几里得公式能够同时提升Prototypical Networks和Marching Networks 的分类准确率。
Episode 的构建:Episode主要类别数 N C N_C NC和每个类别的支持集数量 N S N_S NS构。如果是进行5类别分类、一样本迁移的实验那么 N C = 5 N_C=5 NC=5 N S = 1 N_S=1 NS=1然后作者研究了这两个值对实验效果的影响。发现训练的时候要尽量使用更多的 N C N_C NC这样能提升测试的时候的效果、对于 N S N_S NS要尽量让训练的时候的 N S N_S NS和测试的时候一样。

2.7 零样本迁移的设计

2.7.1 补充知识

少样本迁移:训练的时候就给少量的信息(比如说每一个类就给一张图片),然后当给出训练集之外的图片(类别是跟上面训练的时候的类别是一样的)的时候,模型需要通过训练的时候的少量信息判别图片是什么类别的图片。
零样本迁移:这个跟上面的区别就是,使用模型的时候需要模型去判断一个未见过的类。举个例子,比如说模型经过训练后能够识别出狮子这个类,但是我通过给模型一个语义信息“老虎像狮子,身上有黑色条纹”,模型通过本来训练时候的信息和语义信息判断出一个新类别:老虎。

2.7.2 文中设计零样本迁移

在这里插入图片描述上图是少样本和零样本的总览图。零样本本来就要求模型在原来对于某一类别没有经过训练,通过附加一定的额外信息(语义信息等、文中的元数据信息) c = c 0 + v c = c_0 + v c=c0+v,这里的c是原来 c 0 c_0 c0加上v所零样本迁移出的类别。v是元数据它可以提前设置好或者是从原始文本中获取!由于模型本身需要通过跟支持集生成的原型(也就是少样本中的c_1、c_2、c_3)对比距离才能产生分类结果,由于零样本迁移不提供支持集,因此作者直接通过v生成类别原型,在文中简单的定义新的类别原型为 c k = g θ ( v k ) c_k=g_\theta(v_k) ck=gθ(vk),由于元数据v和查询集x是来自不同域的输入(往往v是文本类型数据,x是图片数据)因此作者发现将编码类别原型( g g g)限制成一个单位,更有利于区分。对于对x的编码( f f f)则不适合,原因是x中往往包含丰富的信息,因此不适合缩小。

3 实验

3.1 实验设置

少-shot学习实验: 他们在 Omniglot 数据集和 miniImageNet 版本的 ILSVRC-2012 数据集上进行了实验。具体来说,他们使用了 Ravi 和 Larochelle 提出的划分方法来处理这些数据集。这两个数据集广泛用于少-shot学习任务:

  • Omniglot:一个包含手写字符的图像数据集,通常用于少-shot学习的基准测试。 miniImageNet:从
  • ILSVRC-2012(ImageNet)中提取的子集,同样用于少-shot学习实验。

零-shot学习实验: 他们在 Caltech UCSD 鸟类数据集(CUB-200 2011) 上进行了零-shot实验。这个数据集包含200种鸟类的图像,是图像分类任务中常用的零-shot学习数据集之一

数据集:他们使用 Omniglot 数据集进行实验,这个数据集包含1623个手写字符,来自50种不同的字母表,每个字符有20个不同人书写的例子。他们将图像缩放到 28×28,并通过旋转字符类(每次90度)进行数据增强,训练集中包含1200个字符及其旋转版本(共4800类),其余的用于测试。

  • 每个类别对应一个字符:Omniglot 包含来自 50 种不同字母表的 1623 个字符。每个字符作为一个独立的类别。

  • 每个类别有 20 个样本:对于每个字符类别,数据集中提供了 20 个不同的人手写的版本。这些样本用来表示同一个字符的不同变体。

  • 数据增强后的类别:在实验中,为了增加样本量,他们将每个字符通过旋转(每次旋转 90 度)进行数据增强。因此,经过旋转后,每个字符变成 4 个类(原始方向和3个旋转方向),大大扩展了训练和测试集中的类别数量。

模型结构:模型采用四层卷积网络,每层包含64个滤波器的3×3卷积、批量归一化、ReLU激活和2×2最大池化层,输出为64维嵌入空间。同样的编码器用于支持集和查询集。模型使用Adam优化器进行训练,初始学习率为0.001,并在每2000轮将学习率减半。

3.2 实验1

基线:Neural Statistician、Meta-Learner LSTM、MAML 及fine-tuned Matching Networks、
non-fine-tuned Matching Networks
实验结果
在这里插入图片描述结果就是本文方法达到了最优的水平
可视化:为了更加形象的表示出匹配过程和结果,作者展示了原型网络学到的嵌入空间的一个t-SNE的一个可视化示例。
在这里插入图片描述可视化中展示了测试集中 Tengwar 字母表的一个子集,类原型用黑色表示。多个被错误分类的字符用红色高亮显示,并通过箭头指向其正确的类原型。
如何去实现这个可视化?

3.2 在miniImageNet上的实验结果

实验设置:在实验中,使用了 Ravi 和 Larochelle 提出的划分方法,将 100 个类别划分为 64 个训练类、16 个验证类和 20 个测试类,训练主要在 64 个类别上进行,验证集用于监控泛化性能。

网络架构:使用与 Omniglot 数据集相同的四层嵌入网络,但由于 miniImageNet 图像尺寸较大,嵌入空间为 1,600 维。

训练过程:使用 1-shot 任务时是 30-way 训练,5-shot 任务时是 20-way 训练。每个类别的查询点固定为 15 个,训练和测试时的 shot 数保持一致。

基线比较:结果与 Ravi 和 Larochelle 报告的基线进行比较,包括最近邻方法和未微调的 Matching Networks 以及 Meta-Learner LSTM。实验结果表明,原型网络在 5-shot 任务上的准确率显著优于其他先进方法。

实验结果
在这里插入图片描述

3.3 距离度量方式和每轮训练类数量对原型网络和 Matching Networks 性能的影响

结果显示20-way 的准确率高于 5-way,并推测 20-way 分类的难度增加有助于网络更好地泛化,因为它迫使模型在嵌入空间中做出更精细的决策。此外,使用欧氏距离相比余弦距离显著提高了性能。这种效果在原型网络中更加明显,因为计算类原型时将支持点的嵌入均值更自然地适应欧氏距离,而余弦距离不是 Bregman 散度。

零样本迁移实验

在 Caltech-UCSD 鸟类(CUB)200-2011 数据集 [34] 上进行了实验。CUB 数据集包含 11,788 张图像,涉及 200 种鸟类。我们严格遵循 Reed 等人 [25] 的数据准备过程,使用他们的划分将类分为 100 个训练集、50 个验证集和 50 个测试集。对于图像,我们提取了 1,024 维特征,这些特征是通过将 GoogLeNet [31] 应用于原始图像及其水平翻转图像的中间、左上、右上、左下和右下裁剪部分获得的。在测试时,我们仅使用原始图像的中间裁剪部分。对于类的元数据,我们使用 CUB 数据集中提供的 312 维连续属性向量,这些属性编码了鸟类的各种特征,如颜色、形状和羽毛图案。

我们在 1,024 维图像特征和 312 维属性向量的基础上学习了一个简单的线性映射,以生成 1,024 维的输出空间。我们发现对类原型(嵌入的属性向量)进行单位长度归一化很有帮助,因为属性向量与图像来自不同的领域。训练情境是以 50 个类和每类 10 张查询图像构建的。嵌入通过固定学习率为 1 0 − 4 10^{-4} 104 和权重衰减为 1 0 − 5 10^{-5} 105的 Adam 优化算法进行优化。使用验证损失的提前停止来确定在训练和验证集上重新训练的最佳轮数。
在这里插入图片描述
表 3 显示,与使用属性作为类元数据的方法相比,我们取得了最先进的结果。我们的原型网络超越了 Synthesized Classifiers,并且与 Zhang 和 Saligrama [36] 的结果在误差条范围内,同时我们的方式比两者都简单。
此外,我们还进行了另一组具有更强类元数据的零样本实验。我们使用预训练的 Char CNN-RNN 模型 [25] 为每个 CUB-200 类提取了 1,024 维元数据向量,然后按照上述相同的程序训练零样本原型网络,不过我们使用了通过验证准确率选择的 512 维输出嵌入。我们的测试准确率达到了 58.3%,而 DS-SJE [25] 使用 Char CNN-RNN 模型获得的准确率为 54.0%。此外,我们的结果超过了 DS-SJE 使用更强的 Word CNN-RNN 类元数据表示所取得的 56.8% 准确率。综合来看,这些零样本分类结果表明,我们的方法具有足够的通用性,即使在数据点(图像)与类(属性)来自不同领域的情况下也能应用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

代码飞速跑

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值