Prototypical Networks for Few-shot Learning
Introduction
- 小样本学习定义: 小样本学习任务要求分类器能够适应在训练集中未出现过的新类别,并且对于每个新类别,只有极少数的样本可供学习。
- 问题的重要性: 尽管这是一个非常困难的问题,但人类展现出了在极端条件下(如单样本学习)进行准确分类的能力。因此,开发能够在有限样本下有效学习的机器学习模型具有重要意义。
- 现有方法的局限性(2017年之前): 直接在新数据上重新训练模型会导致严重的过拟合问题。虽然已有一些方法在小样本学习上取得了进展,但仍需寻找更有效的解决方案。
- 原型网络的提出: 为了解决小样本学习中的过拟合问题,作者提出了原型网络(Prototypical Networks),这是一种基于原型的简单方法,通过学习输入空间到嵌入空间的非线性映射来进行分类。
- 原型网络的优势: 与最近的一些复杂方法相比,原型网络展现了更简单的归纳偏置(inductive bias),在有限数据条件下表现出色。此外,原型网络的设计决策简单,但在性能上却能显著超越涉及复杂架构选择和元学习的方法。
- 零样本学习的扩展: 作者还扩展了原型网络以处理零样本学习任务,其中类别是通过元数据而非标记样本来定义的。
原型网络
小样本问题建模
support set:
S
=
{
(
x
1
,
y
1
)
,
.
.
.
.
.
.
.
.
.
,
(
x
N
,
y
N
)
}
S=\{(x_1,y_1),.........,(x_N,y_N)\}
S={(x1,y1),.........,(xN,yN)}
x i ∈ R D , y i ∈ { 1 , 2 , 3 , . . . . . . . K } x_i \in \mathbb{R}^D , y_i \in \{1,2,3,.......K\} xi∈RD,yi∈{1,2,3,.......K}
S k = { ( x 1 , y 1 ) , . . . . . . . ( x i , y i ) . . . . . . ∣ y i = K } S_k=\{(x_1,y_1),.......(x_i,y_i)...... | y_i =K\} Sk={(x1,y1),.......(xi,yi)......∣yi=K}
其中xi是D维度的特征向量,yi是xi对应的标签,yi的种类有1~K不同类型,Sk为标签均为K的子集
原型网络建模
模型学习出的函数 f ϕ : R D → R M f_\phi :\mathbb{R}^D \to \mathbb{R}^M fϕ:RD→RM,原型网络学习一个编码函数,将输入的D维的xi,编码为M维度的 f ϕ ( x i ) f_\phi(x_i) fϕ(xi)。然后按类别不同,对每个类别计算原型 c k c_k ck。
c k = 1 S K ∑ ( x i , y i ) ∈ S k f ϕ ( x i ) c_k=\frac{1}{S_K} \sum_{(x_i,y_i) \in S_k}f_{\phi}(x_i) ck=SK1∑(xi,yi)∈Skfϕ(xi)
每类样本的编码求均值得到原型。
针对某个输入样本,如何确定他所属类别:
对此样本和所有原型求距离,然后算softmax。
损失函数
J
(
ϕ
)
=
−
log
p
ϕ
(
y
=
k
∣
x
)
J(\phi)=-\log_{p_{\phi}}(y=k|x)
J(ϕ)=−logpϕ(y=k∣x),最小化真实类别负对数。
训练时的“episode”是通过从训练集中随机选择一部分类别形成的,然后在每个类别中选择一部分样本作为支持集(support set),剩余的部分作为查询点(query points)。
原型网络的整体流程(伪代码):
符号表示:
N大训练集样本总数,K大训练集类别总数。每个“episode”是
N
c
N_c
Nc ways
N
s
N_s
Ns shots with
N
Q
N_Q
NQ queries (相当于每个eposide每个类别的的测试集样本数有
N
Q
N_Q
NQ个)。
R
A
N
D
O
M
S
A
M
P
L
E
(
S
,
N
)
RANDOMSAMPLE(S, N )
RANDOMSAMPLE(S,N)从S集合采样N个。
- 疑惑:这里求原型为什么不除Ns呢???????
本文的距离度量采用平方欧几里得距离
之后作者论述了选择距离度量函数的依据,和实验结果,年代久远,不深究了。