论文阅读《Matching Networks for One Shot Learning》

提出了一种基于度量学习和利用外部记忆来增强网络的方法,用一些有标签的 support set 来推测无标签的 query,是一种小样本的图像分类的方法。

18年的文章,可以说是小样本学习领域的经典文章。文章自己认为的创新有两面,一个是在模型层面提出 Matching Net,第二个是遵循了一个简单的机器学习准测:训练与测试的数据分布必须相近(这也算创新?)。提出了在 ImageNet 上的一个 benchmark。

文章提到了几个激发他们灵感的网络,大都有神经注意力机制,可微分并且使用一个外部记忆来解决任务。可以理解为P\left (B\mid A\right ),其中 A 为 support set,B 为 query set。

Matching Network

实现 f()、g() 有两种办法:CNN 和 LSTM。

  • 神经注意力机制(CNN)

对于每一个 support set S(small),训练一个分类器 Cs,对每一个输入 x(可以是图片也可以是句子),输出其可能的概率分布:

xi 和 yi 是 support set 中的数据,\hat{x} 和 \hat{y} 是 query set 中的数据,P() 使用模型来实现的,a(文中形容为 attention mechanism 不知该这么理解)是一个核函数。这个公式可以看作是对 support set 中数据的线性组合,这个公式类似于一个核密度估计。选用合适的距离函数和常数,对于离 \hat{x} 第 b 近的 xi,a() 的结果都是0,这时这个公式等价于 “k-b” 近邻。

另一个解释的角度是把这个公式看作一个 hash 表,每一个 xi 都有对用的响应 yi。就像有了可以联想的记忆,给定一个输入我们可以检索出 support set 中与之类似的案例。

因此分类器 Cs 的形式非常灵活,并且很容易就可适应到新的 support set 上。query 应能与 support set 中相一致的图片类别相对齐。

这个 a() 最简单的实现方法是使用余弦距离和 softmax,当然也可以是用别的方法来计算需预测类别的输入 x' 与所有给定标签的输入 xi 的关系。

  • 外部记忆(LSTM)

先来简单说一下为何作者提出外部记忆。作者认为上述的余弦注意力定义的时候,(输入任务S中)每个已知标签的输入 xi 通过CNN后的embedding,也就是 g(xi) 是独立的,前后没有关系,然后与 f(x') 进行逐个对比,这看起来就有点简单粗暴,没有考虑到输入任务S改变embedding x' 的方式,也就是 f() 应该是受 S 影响的。(x' 为 \hat{x}

利用 f 和 g 来表征 support 和 query 的嵌入,文章认为 support 应该不仅影响 g 表征的方式,也应该能够影响 f 表征 query 图片的方式。g 的表征方式用双向的 LSTM 来实现,使其能够充分编码 support set 内的信息,support 影响 f 表征 query 的方式也是使用 LSTM:

即先将任务 S 中所有图片 xi(假设有 K 个)和目标图片 x'(假设有1个)全部通过 CNN 网络获得他们的 embedding,然后将这 K+1 个向量全部输入到双向 LSTM 中,获得 K+1 个输出。然后使用余弦距离判断前 K 个输出中每个输出与最后一个输出之间的相似度,相似度最高的 xi 的标签就是 x' 的预测值。

LSTM 不展开。

  • 训练策略

对一个任务 T 和带标签的数据 L,最多包含5类,每一类最多有5张图片。例如 L 可以是 {cats, dogs},需要抽样出 B 个 batch 中的 support set S,S 是带标签的包含 cat 和 dog 的图片。Matching Net 的任务是减少在 B 上对 support set 进行分类的损失,其损失函数如下:

This is a form of meta-learning since the training procedure explicitly learns to learn from a given support set to minimise a loss over a batch.

训练完成后,在 novel 类别中再抽样出 S' 和 T',再调用 θ 完成分类任务,当 T' 与 T 相差较大时效果不好。

Experiments

在 L 上完成训练,在 L' 上完成测试。也测试了经过 L’ 微调后的性能。

Note that, even when fine tuning, the setup is still one-shot, as only a single example per class from L' is used.

在两个数据集 Omniglot 和 ImageNet 上的 N-way k-shot 分类任务。

  • Omniglot

  • ImageNet

分别测试了在三个子数据集(randImageNet、dogsImageNet、miniImageNet)和 ImageNet 上的性能。还完成了对上面介绍的外部记忆的消融实验,原文中称作 Full Contextual Embeddings (FCE)。

randImageNet 是在训练集中抽了118个类的118张图片,仅在这些图片上做测试,记作 Lrand。

dogsImageNet 是在除了包含 dog 的类别上训练,在包含 dog 的类别上做测试。

在 ImageNet 上采用的 baseline 是 Inception,用 Inception 的分类器初始化了 Matching Net 中的 f 和 g,再以 5-way 1-shot 的方式在训练集进行训练。

这里的 Inception 是 trained to classify on all classes except those in the test set of classes (for randImageNet) or those concerning dogs (for dogsImageNet)。

However, on the much more challenging Ldogs subset, our model degrades by 1%. We hypothesize this to the fact that the sampled set during training, S, comes from a random distribution of labels (from ≠ Ldogs), whereas the testing support set S' from Ldogs contains similar classes, more akin to fine grained classification. Thus, we believe that if we adapted our training strategy to samples S from fine grained sets of labels instead of sampling uniformly from the leafs of the ImageNet class tree, improvements could be attained. We leave this as future work.

Conclusion

如果训练网络进行 one-shot,那么 one-shot 会容易得多。

神经网络中的非参数结构使网络更容易记忆和适应相同任务中的新数据集。

附加

  • 核密度估计

是对直方图的扩展,针对直方图存在的不连续、受终点影响大、受间隔影响大等缺点进行的改进。将每个数据点替换成相应的核函数(比如高斯核函数),再将所有点的值叠加再归一化,即得到核密度估计曲线。带宽选择也十分重要。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值