小样本学习&元学习经典论文整理||持续更新
核心思想
本文提出一种采用直推式传播网络(Transductive Propagation Network,TPN)的小样本学习算法。在介绍本文之前,我们首先了解一下什么是直推式学习(Transductive Learning),我们常见的有监督学习方法其实属于归纳学习(Inductive Learning),也就是训练集是带有真实标签的,而测试集是不带有真实标签的,且二者之间不存在重合的部分。但直推式学习则是将带有标签的训练集和不带标签的测试集都输入到网络进行训练,然后再预测这部分测试集的结果(一个形象的例子就是在布置课后作业时,把考试原题给你了,但不给你答案)。这种方法与半监督学习也有不同,半监督学习中所使用的无标签数据集和测试集是不相同的两个数据集,而在直推学习中二者是同一个数据集。了解了直推学习后,我们来介绍下本文提出的TPN算法。
如图所示,本文提出的算法包含四个步骤:特征提取,构建图模型,标签传播和损失计算。首先将训练的支持集(带有标签)和查询集(不带标签)一起输入到特征提取网络
f
φ
f_{\varphi}
fφ中得到对应的特征向量。然后,利用一个卷积神经网络
g
ϕ
g_{\phi}
gϕ来构建图模型,所谓构建图模型,其实就是将每个样本当作图中的结点,计算每个节点之间权重
W
i
j
W_{ij}
Wij的过程。常见的方式为高斯相似性函数
其中
d
(
,
)
d(,)
d(,)是一个距离度量函数(如欧式距离),
σ
\sigma
σ表示长度范围参数,这个参数对图模型有较大的影响,需要谨慎选择,但作者发现如果在元学习的过程中去学习这个参数并不是解决问题的根本方式,因此作者选择利用一个CNN,为每个样本都学习一个特有的
σ
\sigma
σ参数,即
σ
i
=
g
(
f
(
x
i
)
)
\sigma_i=g(f(x_i))
σi=g(f(xi)),然后可得到两个节点间的权重为
作者还利用规范化的图拉普拉斯算子处理了节点间的权重
W
W
W,得到
S
=
D
−
1
/
2
W
D
−
1
/
2
S=D^{-1/2}WD^{-1/2}
S=D−1/2WD−1/2,其中
D
D
D是一个对角矩阵,且
(
i
,
i
)
(i,i)
(i,i)处的值为
W
W
W矩阵中第
i
i
i行的值之和。得到图模型之后,下面的工作就是要利用带有标签的支持集样本来推测未带标签的查询集样本了,这一过程称之为标签传播。作者首先构建了一个非负矩阵
F
\mathcal{F}
F其尺寸为
(
N
×
K
+
T
)
×
N
(N\times K+T)\times N
(N×K+T)×N,其中
N
N
N表示支持集中包含的类别数,
K
K
K表示每个类别包含的样本数,
T
T
T表示查询集样本数。矩阵
F
∈
F
F\in\mathcal{F}
F∈F中的值
F
i
j
{F}_{ij}
Fij可以理解为第
i
i
i个样本属于第
j
j
j个类别的概率(此时还没经过归一化处理,所以还不是最终的概率值)。矩阵
F
{F}
F通过迭代训练的方式进行更新,迭代过程如下
式中
F
t
F_t
Ft表示第
t
t
t次迭代的预测矩阵,
Y
∈
F
Y\in\mathcal{F}
Y∈F,如果
Y
i
j
=
1
Y_{ij}=1
Yij=1则表示第
i
i
i个样本属于支持集,且标签为
j
j
j,否则
Y
i
j
=
0
Y_{ij}=0
Yij=0。经过迭代后,
F
F
F收敛为
并利用softmax函数将其转化为概率,得到最终的预测结果。
实现过程
网络结构
特征提取网络由4层卷积块构成,图构建网络包含两个卷积层和两个全连接层。
损失函数
采用交叉熵损失函数
创新点
- 采用直推式的方法进行小样本学习
- 利用图模型和标签传播的方式,对查询集样本类别进行预测
算法评价
本文应该是首篇完全采用直推式学习的方法进行小样本学习的文章,整个思路与基于图神经网络的算法比较接近,但最后采用的标签传播方式进行类别预测的确是之前读的文章中没有出现过的。由于直推式学习本身具备的优势,本文的分类效果与同时期的其他算法相比都有一定的领先,该方法也具备较大的研究潜力。
如果大家对于深度学习与计算机视觉领域感兴趣,希望获得更多的知识分享与最新的论文解读,欢迎关注我的个人公众号“深视”。