Matching Networks for one shot learning阅读笔记

Matching Networks for one shot learning阅读笔记

few/one-shot learning

       顾名思义,就是通过少量的样本去训练和学习。
       我们人类非常善于通过极少量的样本去识别一个新事物,比如小孩子只需要书中的一张图片就可以认识什么是“长颈鹿” 。
       但是现在的主流的深度学习技术需要大量的数据来训练一个好的模型。比如典型的MNIST分类问题,其中包含60000个样本,十个类别(0-9是个数字),但是我们想一下我们人类自己,我们区分 0 到 9 的数字显然不需要这么多数据!所以当训练集中每个类别的数据较少时,机器如何进行高效的学习呢?就就是小样本要解决问题。
       形式化来说,few-shot 的训练集中包含了很多的类别,每个类别中有多个样本。在训练阶段,会在训练集中随机抽取 N 个类别,每个类别 K 个样本(总共 N * K个数据,作为模型的支撑集support set)输入;再从这N个类中剩余的数据中抽取一批(batch)样本作为模型的预测对象(batch set 或 query set)。即要求模型从 N*K 个数据中学会如何区分这 N个类别,这样的任务被称为 N-way K-shot 问题。其中K值一般较小,当K=1时,即变成了One-shot。

LSTM

       长短期记忆网络(LSTM, Long Short-Term Memory),是一种循环神经网络(RNN)的一个变种模型,继承了大部分RNN模型的特性,解决一般的RNN存在的长期依赖、梯度消失、梯度爆炸等问题。LSTM非常适合用于处理与时间序列高度相关的问题,例如机器翻译、对话生成、编码、解码等。
       其过程可以概括为,通过对这张图是包含了三个连续循环结构的LSTM,即三个time step,其中每个循环结构有两个输入,分别是上一时刻的隐藏层状态 h t − 1 h_{t-1} ht1和当前输入 x t x_t xt
       经典LSTM结构中有3个门和一个细胞状态:(1)遗忘门 (2)输入门 (3)输出门 (4)细胞状态 C t C_t Ct
       细胞状态中信息的遗忘和记忆新的信息使得有用的信息得以传递,而无用的信息被丢弃,并在每个时间步都会输出隐层状态 h t h_t ht。其中遗忘,记忆与输出是通过遗忘门 f t f_t ft,输入门 i t i_t it,输出门 o t o_t ot进行控制的。
在这里插入图片描述
接下来我将LSTM中的每个门及细胞状态的作用进行说明:
    (1)遗忘门: 用来选择性的遗忘某些历史信息,即决定上一时刻的细胞状态有多少需要被遗忘掉。它查看 h t − 1 h_{t-1} ht1(前一个输出)和xt(当前输入),通过sigmoid层将其值限制到0和1之间的数字。1代表完全保留,而0代表彻底遗忘。
    (2)输入门: 用来加强对某些信息的记忆,即决定当前时刻网络的输入数据有多少需要保存到细胞状态。首先,称为“输入门层”的Sigmoid层决定了我们将更新哪些值。 接下来一个tanh层创建候选向量 C t ~ \tilde{C_t} Ct~,该向量将会被加到细胞的状态中。 在下一步中,我们将结合这两个向量来创建更新值。
    (3)当前细胞状态: 经过遗忘和添加后保留下来的记忆信息,即更新后的细胞状态。是通过这个公式计算出来的,其中 f t ∗ C ~ t − 1 f_t*\tilde{C}_{t-1} ftC~t1就代表需要遗忘的记忆, i t ∗ C t ~ i_t*\tilde{C_t} itCt~就代表当前要保留的记忆。
    (4)输出门: 控制当前单元状态有多少需要输出,即用来对长短期记忆信息进行综合考虑,生成输出信号。首先,我们运行一个sigmoid层,它决定了我们要输出的细胞状态的哪些部分。 然后,我们将细胞状态通过tanh(将值规范化到-1和1之间),并将其乘以Sigmoid门的输出,即只输出了我们决定的那些部分。

Attention mechanism

在这里插入图片描述

        接下来我们来了解一下注意力机制。深度学习中的注意力机制从本质上讲和人类的选择性视觉注意力机制类似,核心目标也是从众多信息中选择出对当前任务目标更关键的信息。至于Attention机制的具体计算过程,可以将其归纳为三个阶段:
   (1)计算计算Query和Key_i之间的相似性或者相关性。我们可以将Key看成数据集中的样本特征,Value看成是样本的标签。Query就是待测样本的特征。这里列出了几种常见的度量方法。
    (2)使用使用softmax对stage1中的原始分值进行归一化处理,即得到相应的概率或者权值,Query与Key越相似其权值或者概率就越大。一般采用如下的计算公式:
    (3)通过stage2的结果计算Attention数值。即用其标签乘以相应的权值然后在求和,就得到了最后的Attention值,即查询样本和训练集的相似性。

本文的主要贡献

1. At the modeling level:
       Matching Networks: uses recent advances in attention and memory.
       在模型层面上,作者提出了一个Matching Networks, 一种使用注意力记忆力机制加速学习的网络结构。
2. At the training procedure:
       Learning principle: test and train conditions must match.
       在训练流程上,作者提出训练模型时要遵循测试和训练条件必须匹配。由于在测试时通常使用每个类较少的样本,所以作者提出在训练时也仅用每个类别中很少的样本进行训练。
       训练过程中,每次训练(episode)都会采样得到不同 Support和Batch,所以总体来看,训练时的训练集包含了不同的类别组合,通过这种机制学习,使得模型学会了从不同类别样本中的提取共性部分,比如如何提取重要特征及比较样本相似等。就是相当于让模型学会分辨支持集Support和查询集合Batch中是否相似,所以在面对新类别样本时,该模型也能较好地进行分类。

Matching Networks architecture

在这里插入图片描述
       对于Matching Networks的结构如图所示:左边的四张狗就相当于Support Set,右边的这一张就是Batch Set。
       首先我们定义一个 S → C s S\rightarrow C_s SCs的映射为 P ( y ^ ∣ x ^ , S ) P(\hat{y}|\hat{x} ,S) P(y^x^,S),该映射基于当前的S,对每个未见过的测试样本 x ^ \hat{x} x^ 给出其标签 y ^ \hat{y} y^的概率分布。可以把 y ^ \hat{y} y^理解为一个向量里面是取得每种类别的y的概率。 比如对于右边的 x ^ \hat{x} x^,它取得左边每种类别的概率可以用一个向量进行表示:[0.1, 0.1, 0.1, 0.7]
       模型的训练过程就可以看作是:给定一个有k个样本的支持集S, 对给定的batch样本 x ^ \hat{x} x^,利用上述的映射可以得到预测的 y ^ \hat{y} y^的概率分布。
       在测试过程中,给定一个新的支持集 S ′ S' S,我们可以用训练学习到的模型对每个测试样本 x ^ \hat{x} x^得到他们可能的label y ^ \hat{y} y^为: P ( y ^ ∣ x ^ , S ′ ) P(\hat{y}|\hat{x},S') P(y^x^,S)

对于给定的一个测试样本 x ^ \hat{x} x^计算 y ^ \hat{y} y^的过程为 P ( y ^ ∣ x ^ , S ) P(\hat{y}|\hat{x} ,S) P(y^x^,S), 可以简化为:

y ^ = ∑ i = 1 k a ( x ^ , x i ) y i \hat{y}=\sum_{i=1}^{k}a(\hat{x},x_i)y_i y^=i=1ka(x^,xi)yi ,其中 a ( x ^ , x i ) = e c ( f ( x ^ ) , g ( x i ) ) ∑ j = 1 k e c ( f ( x ^ ) , g ( x j ) ) a(\hat{x}, x_i)=\frac{e^{c(f(\hat{x}), g(x_i))}}{\sum_{j=1}^{k} e^{c(f(\hat{x}), g(x_j))}} a(x^,xi)=j=1kec(f(x^),g(xj))ec(f(x^),g(xi))

其中 a ( x ^ , x i ) a(\hat{x},x_i) a(x^,xi)代表注意力机制, c ( , ) c(,) c(,)是余弦相似度,用来计算 x ^ \hat{x} x^ x i x_i xi之间的相似度。f和g分别是对 x ^ \hat{x} x^ x i x_i xi的嵌入函数。其中f和g被分别参数化为用于图像任务的深度卷积网络(如VGG 或Inception)或用于语言任务的简单形式词嵌入。 y i y_i yi x i x_i xi样本的标签,这样Batch集中的样本就可以通过支持集中的样本进行表示,即一个新样本 y ^ \hat{y} y^的输出是S集中的样本类别基于attention线性组合。

       即我们的 y i y_i yi其实可以看成是One-Hot编码,然后 y ^ \hat{y} y^则就是取得每种类别的概率。然后通过g和f提取特征,然后进行比较,输出其中 y ^ \hat{y} y^中概率最大的那一类就可以作为模型的预测标签。比如对于左边的四种狗使用One-hot进行编码,我们可以将它们分别标注为[1,0,0,0], [0,1,0,0],[0,0,1,0],[0,0,0,1], 我们假设对于 x ^ \hat{x} x^的预测标签取得左图四张狗的概率分别为: 0.1, 0.1, 0.1, 0.7,则通过注意力机制:

y ^ = ∑ i = 1 k a ( x ^ , x i ) y i \hat{y}=\sum_{i=1}^{k}a(\hat{x},x_i)y_i y^=i=1ka(x^,xi)yi

       我们最终求得的y’即概率分布P(y’|x’,S),为[0.1, 0.1, 0.1, 0.7], 这就是模型预测的概率分布,然后我们选择概率最大的类别作为其预测的类别即可。

优化目标

然后在引入一下多分类的损失函数即交叉熵损失,其计算公式为:

l o s s = ∑ i = 1 n y i ∗ l o g ( y ^ i ) loss = \sum_{i=1}^{n}y_i*log(\hat{y}_i) loss=i=1nyilog(y^i)

其中 y i y_i yi是指示变量(0或1),就相当于类别 y y y的One-hot编码后的每位的值; y ^ i \hat{y}_i y^i是对于观测样本属于类别 y i y_i yi的预测概率。n就Support中的样本的类别数量。就比如当我们输入一张狗的图片时,标签的预测 y ^ \hat{y} y^为[0.2,0.7,0.1], 实际标签 y y y为[0,1,0],计算的损失就为0.36。
在这里插入图片描述在这里插入图片描述

       为了计算梯度和更新我们的模型,我们首先从任务T中采样得到标签集L,然后在从L中采样得到support集S和batch集B,进行训练,我们模型的训练目标是:使基于Support集的Batch集中样本标签的预测误差最小化,即最小化Loss。其实这个Loss其实就是对x`的真实标签的预测概率取LogP的相反数,其实可以转换成最大化-Loss即最大化这个式子,所以这个式子就是模型的优化目标:

θ = a r g m a x θ E L ∼ T [ E S ∼ L , B ∼ L [ ∑ ( x , y ) ∈ B l o g P θ ( y ∣ x , S ) ] ] \theta = \underset{\theta}{argmax }E_{L\sim T}[E_{S\sim L, B\sim L}[\sum_{(x,y)\in B}logP_\theta(y|x,S)]] θ=θargmaxELT[ESL,BL[(x,y)BlogPθ(yx,S)]]

全文嵌入(Full Context Embeddings,FCE)

       然后作者提出又全文嵌入即FCE。即作者认为对样本 x ^ \hat{x} x^ x i x_i xi嵌入的时候,需要考虑到支持集S中样本的影响,由于每次训练(episode)都会采样得到不同 Support和Batch,所以嵌入的时候需要受支持集S中样本数据分布的调控,其嵌入过程需要放在整个支持集环境下进行,也就是 f ( x ^ ) f(\hat{x}) f(x^) g ( x i ) g(x_i) g(xi)的特征提取应该受到S集的影响。即将 f ( x ^ ) f(\hat{x}) f(x^) g ( x i ) g(x_i) g(xi)变为 f ( x ^ , S ) , g ( x i , S ) f(\hat{x},S),g(x_i,S) f(x^,S),g(xi,S)

       因此作者采用带有注意力的LSTM网络对样本 x ^ \hat{x} x^进行嵌入。 其中 f ′ ( x ^ ) f'(\hat{x}) f(x^)是对 x ^ \hat{x} x^的嵌入函数,可以使用VGG或者Inception等进行嵌入,K是LSTM的展开步数,g(S)是参与嵌入的支持集S。来让我们看一下经过k步后的展开公式:这允许模型潜在地忽略支撑集S中的某些元素,并集中注意力计算某些元素。最后 f ( x ^ , S ) f(\hat{x},S) f(x^,S)的嵌入结果为最后一步LSTM输出的隐藏层状态。
       同样对于通过S集中样本 x i x_i xi的嵌入, 作者使用 g ( x i , S ) g(x_i,S) g(xi,S)函数利用双向LSTM来嵌入集合的元素,这个函数除了 x i x_i xi外还将整个支持集S作为输入。 g ′ ( x i ) g'(x_i) g(xi)是首先对xi的嵌入(如使用VGG、Inception model)。当S集中的某些样本 x i x_i xi x j x_j xj非常相似时,这会很有用。双向循环神经网络的主体结构是由两个单向循环神经网络组成的。在每一个时刻t,输入会同时提供给这两个方向相反的循环神经网络,而输出则是由这两个单向循环神经网络共同决定。
       最后的实验结果表明,引入了FCE的Matching Network的性能得到了明显的提升。

整个训练过程

       所以整个训练的过程可以通过这张图片展示出,首先,通过嵌入函数 g ( x i , S ) g(x_i, S) g(xi,S) f ( x ^ , S ) f(\hat{x},S) f(x^,S)进行特征的提取,然后通过注意力机制 a ( x ^ , x i ) a(\hat{x},x_i) a(x^,xi)计算权值,然后计算出样本 x ^ \hat{x} x^的预测结果 P ( y ^ ∣ x ^ , S ) P(\hat{y}|\hat{x},S) P(y^x^,S), 然后利用交叉熵损失优化我们的目标进行参数的求解。
在这里插入图片描述

实验

       这篇文章的实验部分都是围绕着一个基本任务:N-way k-shot learning task。 即对N个类别,每个类别给定k(1或5)个样本,并且测试过程中的样本以及类别都是在训练过程中未见过的,即用来判断一个新的样本的类别。作者选择了几个对比模型与匹配网络进行比较。
       Fine Ture代表微调,就是用从已训练好的模型中获得的参数来初始化自己的网络,然后用自己的数据接着训练即对训练的模型进行微小改变;Matching Fn代表匹配函数,其中Softmax就代表一个全连接层经过一个softmax非线性激活在进行匹配,Consine就代表使用余弦距离进行匹配。

Omniglot 数据集上的实验

       首先是在Omniglot上的实验:Omniglot 是一个类似 MNIST 的数据集,一共有来自50个不同字母表的1623个不同手写字符组成,每个字符都是由20个不同的人手工绘制的。其中选择1200种字符进行训练,其余字符用于测试。
       实验的结果如图所示,其中从实验结果我们可以看出,对于1-shot,5-shot,5-way,20-way来说,匹配网络模型均优于其他分类器。但是FCE没什么很大的帮助,在这就没有列出。
在这里插入图片描述
在这里插入图片描述

ImageNet 数据集上的实验

       ImageNet是一个计算机视觉数据集,总共有21841个类别(synsets),14197122幅图像,在这个数据集上作者进行了三个不同的实验,分别对不同的数据集进行训练和预测。
在这里插入图片描述
在这里插入图片描述

miniImageNet

       首先是作者新定义的一个数据集 miniImageNet —— 一共有100个类别,每个类有600个样本。其中80个类用于训练20个类用于测试。在miniImageNet上的实验,我们可以看出MatchingNets的效果比其他分类器效果好,同时对于FCE的作用我们也可以看出,提升了大概两个百分点。

randImageNet

       在训练集中随机去除了118个label的样本,并将这118个标签的样本用于之后的测试。

dogsImageNet

       从ImageNett数据集中移除了所有属于狗这一大类的样本(一共118个子类),之后用这118个狗的子类样本做测试。
       作者解释到在dogsImageNet上的实验更类似于细粒度分类,如果调整训练策略,从细粒度的标签集合中取样,而不是从ImageNet类别中均匀地取样,模型的结果可以得到提升

Penn TreeBank 数据集上的实验

       作者还介绍了一个One-shot 语言任务:给定一个包含缺失单词的查询语句,以及一组支持语句(每个句子都有一个缺失单词和一个对应的1-hot标签),从支持集中选择与查询语句最匹配的标签。这里我们展示了一个例子,但是注意右边的单词没有被提供,集合的标签是以1-hot-of-5向量给出的。这个实验是在Penn Treebank数据集上进行的,结果如下图所示,但是实验结果并不理想。
#### 四级标题
在这里插入图片描述

参考

论文原文:Matching Networks for one shot learning
【译】理解 LSTM(Long Short-Term Memory, LSTM) 网络
注意力机制的基本思想和实现原理(很详细)(第二篇)
交叉熵损失函数原理详解
基于匹配网络(Matching Networks)的FSL方法简述(一)

  • 6
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值