元学习—高斯原型网络

本文介绍了高斯原型网络,一种改进的原型网络结构,用于小样本学习任务。该模型不仅输出编码向量,还估计了每个样本的质量,通过学习编码的协方差矩阵来预测置信区间。通过调整不同的协方差矩阵构建方法,模型能够在编码空间中学习依赖方向的距离度量,从而提高分类效果。在Omniglot数据集上的实验表明,高斯原型网络相对于传统原型网络具有优势。
摘要由CSDN通过智能技术生成

元学习—高斯原型网络

本博客源于文献《Gassian Prototypical Networks for Few-Shot Learning on Omniglot》,只选择的有关模型描述的部分进行了翻译,有兴趣的读者可以进行参考全文。同时,该文章没有仔细的介绍原型网络的基本结构,有兴趣的读者可以参考我的另外一篇博客元学习——原型网络(Prototypical Networks)

1 背景介绍

1.1 高斯原型网络引入

在本篇文献中,我们基于原型网络提出了一个新的结构,并且将其在Omngilot数据集上进行训练和测试。在原型网络中将图片映射到编码向量中,并且使用类似于聚类中的距离计算方式来进行分类。原型网络将每一个批次的数据分为支持集查询集图片,并且使用支持集生成的编码向量来定义“类别元”,即对于一个类别的描述。这些近似的描述被用于后续的分类。

对于我们的模型,我们称之为高斯原型网络,其将图片映射成编码向量和一个对于图片质量的估计。生成编码向量的同时,关于编码向量的置信区间也被生成,这个置信区间通过高斯协方差矩阵来进行表征。高斯原型网络学习一个基于编码空间的依赖于方向和类别的距离度量。我们将展示使用额外的可训练的参数要比传统的原型网络效果更好。

我们的目标是通过让我们的模型在孤立数据点中解释其置信度来说明我们获得了更好的效果。我们也有意的破坏了我们数据,以此来引入我们的模型对于噪声,异构数据集(即孤立数据的权重将对模型的表现产生重要的影响)的可扩展性。

1.2 相关工作

对于一般的不存在参数的模型,比如KNN模型,其是一个理想的小样本分类器,因为其允许对于未见类别的成立性(即仅仅通过样本的距离就判断分类,但是并没有确定是哪一个分类)。然而,这种方法对于距离的度量方式是十分敏感的。直接使用输入空间内的距离计算不会产生太高的精确性,因为图片的类别和像素点之间的联系不是线性的。

一个直接的改进方案是通过为输入学习一个编码来进行KNN分类,这种方法的确取得了不错的效果。另外一种方案是使用匹配网络,其可以有效的学习成对的图片之间的距离度量。模型值得注意的一个特点是其学习框架,其中每一个小的批次通过对于数据集中的类别进行子采样和对于每一个类别的子采样来尝试去最小化数据缺乏的测试条件。这种方法在小样本的分类中,提高了模型的效果。因此,我们后面也会使用到这个采样的方法。

为了替换直接使用数据集,最近被提出来使用LSTM来进行预测更新出一个分类器,其中每一个输入是一个小的batch。这种方法也被称为“元学习”的方法。元学习在Omniglot上表现的效果很好。一个未知任务的,基于一般卷积形式的元学习器已经被提出。对于参数化和非参数化结合的方法最近在小样本的学习上取得了不错的效果。

我们所提出的是一个对于图像分类的具体方法,并且没用尝试通过使用元学习的方式来解决这个问题。我们在将图像映射成编码向量和使用聚类计算的基础上来进行实现。我们模型中的新的特征是通过学习一个依赖于图片协方差矩阵来预测出每个样本的置信度。其允许通过图片所构建出的更加丰富的编码空间。最终利用方向和与类别信息的度量机制来进行最终的分类。

2 方法

在这篇文献中,我们首先来探索传统的原型网络,我们对于传统的进行网络进行结构上的扩展,其称之为为高斯原型网络,其允许模型通过预测编码和其置信区间(使用高斯协方差矩阵表示)来反应出每一个样本的质量。

传统的原型网络中包括了一个编码器来将图片映射成一个编码向量。一个batch的数据博包括了可训练数据集中的一个子集。在每一轮迭代中,每一个类别的图片被随机分成支持集,查询集。支持集的编码被用于定义“类别元”(用向量表示一个类别)。查询图片与类别元的近似性被用于分类。

传统的原型网络和高斯原型网络在编码器的结构上没有什么不同,关键区别在于编码器输出的解释,使用方式以及如何构造嵌入空间上的度量。在高斯原型网络中,编码器输出的一个部分被用于构建对于编码的协方差矩阵。其允许我们的模型来反映预测能力和单独样本的质量。

2.1 编码器

我们使用一个多层的卷积神经网络,同时放弃了最终的全连接网络来对于图像进行编码,生成高维欧式向量的方法。对于传统的原型网络而言,编码器函数以一个图片I作为输入,并且将其转换成一个向量x,其定义公式为:
e n c o d e r ( W ) : I ∈ R H ∗ W ∗ C − > x ∈ R D encoder(W):I∈R^{H*W*C}->x∈R^D encoder(W):IRHWC>xRD
其中,H,W,C为输入图片的长,宽和通道数。D为编码之后的向量维度,其为模型的一个超参数。W为编码器中的参数。
对于一个高斯原型网络,其输出是一个编码向量 x ∈ R D x∈R^D xRD和其相关的协方差矩阵 ∑ ∈ R D ∗ D ∑∈R^{D*D} RDD,则其计算公式为:
e n c o d e r ( W ) : I ∈ R H ∗ W ∗ C − > [ x , s r a w ] ∈ [ R D , R D s ] encoder(W):I∈R^{H*W*C}->[x,s_{raw}]∈[R^D,R^{D_s}] encoder(W):IRHWC>[x,sraw][RD,RDs]
其中 D s D_s Ds是协方差矩阵的维度。对于协方差矩阵的构建,我们探索了下面三种方法:

  1. 半径: 即 D s = 1 D_s=1 Ds=1,即协方差仅仅是一个数值 s r a w ∈ R 1 s_{raw}∈R^1 srawR1。,这种方法为每一个图像生成一个置信区间的表征。因此置信区间所构造的协方差矩阵可以有如下的表现形式:
    ∑ = d i a g ( σ , σ , σ , σ , σ , . . . ) ∑=diag(σ,σ,σ,σ,σ,...) =diag(σ,σ,σ,σ,σ,...)
    其中每一个σ通过原始的输出 s r a w s_{raw} sraw来计算。通过这种方式计算出来的置信度是一个数值,显然其不具有方向的敏感性。这种方法在Omniglot数据集上提供了对于额外参数最为有效的应用。我们猜测这种表示是根据具体的数据所产生的,并且更少的同构数据能够有更加复杂的置信预测。
  2. 对角线, D s = D D_s=D Ds=D,即协方差预测的维度和编码向量的维度是一致的,其中 s r a w ∈ R D s_{raw}∈R^D srawRD通过每一张图片来生成,用来表征编码向量的置信期间。因此,这个协方差矩阵拥有如下的形式:
    ∑ = d i a g ( σ − ) ∑=diag{(σ^-)} =diag(σ)
    其中 σ − σ^- σ表示的通过 s r a w s_{raw} sraw计算出来的向量结果。这种方式允许网络来解释关于一个样本依赖于方向的置信区间。
  3. 全方式:即每一个样本点都可以输出一个协方差的矩阵。这种方法过于复杂,因此后面没有进行进一步的探索。

我们使用下采样的方法,对于Omniglot数据集中的数据点采样28281个像素点作为输入,一个4层的CNN和2*2最大池化层,最后的池化层将数据整理成 1 ∗ 1 ∗ ( D + D s ) 1*1*(D+D_s) 11(D+Ds)的形式,其中D表示的编码的维度,加上对于协方差矩阵的维度 D s D_s Ds,两个部分与最后一层卷积核的数量相同。对于原始的输入,如果维度不满足需求,使用1进行填充。最后一层和一个全连接层类似。

我们研究了两种不同的编码维度。

  1. 小结构:3*3的核,核的数量为[64,64,64,D],若使用半径,则为([64,64,64,D+1]),若使用对角线,则为([64,64,64,D+D])。
  2. 大结构:3*3的核,核的数量为[128,256,512,D],若使用半径,则为([128,256,512,D+1]),若使用对角线,则为([128,256,512,2D])。在上述的两种结构中,D的取值分别为128,256,512

进一步,我们探索了四种不同的方法来对编码器输出的源协方差矩阵转换成实际的协方差矩阵 。因此,实际上,我们最终使用的是编码器输出的协方差矩阵的逆矩阵。即 S = ∑ − 1 S=∑^{-1} S=1,这里我们将编码器的原始输出的协方差矩阵定义为 S r a w S_{raw} Sraw,所有计算方法如下所示:
a. S = 1 + s o f t p l u s ( S r a w ) S=1+softplus(S_{raw}) S=1+softplus(Sraw),其中 s o f t p l u s ( x ) = l o g ( 1 + e x ) softplus(x)=log(1+e^x) softplus(x)=log(1+ex),因为 s o f t p l u s ( x ) > 0 softplus(x)>0 softplus(x)>0,这保证了S>1,并且编码器仅仅能够使得数据点的重要性降低。同时S也不会受到上面的限制,我们的模型在初始的训练中采用了这种策略。
b. S = 1 + s i g m o i d ( S r a w ) S=1+sigmoid(S_{raw}) S=1+sigmoid(Sraw),其中 s i g m o i d ( x ) = l o g ( 1 + e x ) sigmoid(x)=log(1+e^x) sigmoid(x)=log(1+ex),因为sigmoid(x)>0,这保证了S>1,并且编码器的输出只是降低了数据点的重要性。S的值通过上述被限定界限,即S<2,因此编码器更加具有限制性。
c. S = 1 + 4 s i g m o i d ( S r a w ) S=1+4sigmoid(S_{raw}) S=1+4sigmoid(Sraw),因此 1 < S < 5 1<S<5 1<S<5,我们使用这种方法来探索协方差域的变化对于模型表现的影响。
d. S = o f f s e t + s c a l e ∗ s o f t p l u s ( S r a w / d i v ) S=offset+scale*softplus(S_{raw}/div) S=offset+scalesoftplus(Sraw/div),其中offset和scale和div初始为1.0,并且这些参数可以被训练。我们最好的模型是采用这种策略来进行后期的训练,相比于a中所提到的计算方式,这种方法更加灵活。

2.2 小批次训练

原型网络的一个核心是使用小批次训练的策略。在训练过程中,一个子集 N c N_c Nc意味着从训练集中的所有类别进行采样。对于这些类别, N s N_s Ns表示随机采样的支持集数据,同时 N q N_q Nq表示查询集数据。对于支持集中的数据的编码用于定义一个具体的“类元”信息,查询集中的样本用于分类并且进行计算损失。对于高斯原型网络,每一个样本的编码的协方差同时也会被预测。

对于高斯原型网络,协方差矩阵的半径或者对角线的结果与编码同时输出,然后这些值将被用于对于编码向量进行加权,同时计算对于这个类别的协方差矩阵,最终,我们定义查询集样本到类元之间的距离 d c ( i ) d_c(i) dc(i),其中c表示类别,i表示样本。
d c 2 ( i ) = ( x i − p c ) T S c ( x i − p c ) d_c^2(i)=(x_i-p_c)^TS_c(x_i-p_c) dc2(i)=(xipc)TSc(xipc)
这里, x i x_i xi p c p_c pc分别表示的是样本i和类别c的原型向量。并且 S c = ∑ c − 1 S_c=∑_c^{-1} Sc=c1,为原始协方差的逆矩阵。因此,高斯原型网络能够在编码空间中去学习一个类别信息,和依赖于方向的距离度量机制。我们发现训练的速度和其准确性及其依赖于如何利用距离结果来构建损失函数。我们通过线性的欧式距离来得出了最好的优化方式。即上述的距离计算方式。具体的损失函数的定义将在下面的算法描述中给出。

2.3 定义一个类别

原型网络中的一个重点是从支持集中创建一个类别元,我们提出了一个基于权重的编码方式的线性组合来生成类别元信息,这里我们设类别c中包含支持集图片 I i I_i Ii,并且编码器的输出为 x i x_i xi,对于类别c而言,计算出来的协方差矩阵的逆矩阵为 S i c S_i^c Sic,其对角线元素为 s i c s_i^c sic,则类别元信息的计算公式定义为:
p c = ∑ i s i c ∗ x i c ∑ i s i c p_c=\frac{∑_is_i^c*x_i^c}{∑_is_i^c} pc=isicisicxic
其中*表示对位相乘,对于类别的c的协方差矩阵的计算为:
s c = ∑ i s i c s_c=∑_is_i^c sc=isic

2.4 模型评估

为了在测试集中评估模型的准确性,我们对测试集中的所有数据进行分类预测,同时对不同的支持集中的样本数量 N s = k N_s=k Ns=k中的k进行了不同的设置。因为我们没有使用验证集,因此我们在计算的时候,对所有的实验去前5个最高的实验结果,然后进行求平均和归一化的操作。

2.5 算法描述

在这里插入图片描述

3 总结

与传统的原型网络相比,高斯原型网络的包括以下几个方面的改进:

  1. 编码器同时输出编码和置信度结果。
  2. 修改了原始的类别元信息的计算方式。
  3. 同时计算了每一个类别的协方差矩阵结果。
  4. 使用了依赖于方向的距离计算方式。

最后,我们用两张图来理解:

在这里插入图片描述

上述图中,实心的点表示的是编码结构,围绕着某一个节点的椭圆是该点的置信期间,而对于不同的颜色的大的椭圆体是整个类别的置信区间。进一步,对于置信期间的描述为:
在这里插入图片描述
其中灰色的节点为查询集中的数据,基于方向,置信期间来计算查询集样本与各个类别元的距离结果,进而完成相关的分类任务。

4 参考

  1. https://arxiv.org/pdf/1708.02735.pdf
  • 4
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值