原型网络 - Prototypical Network
原型网络出自下面这篇论文。
Snell J, Swersky K, Zemel R S. Prototypical networks for few-shot learning[J]. NIPS 2017.
原理
原理和聚类有点相似
孪生网络的缺点就是要对比目标和过去每个样本之间的相似度,从而分析目标的类别,而原型网络就提出,把样本投影到一个空间,计算每个样本类别的中心,在分类的时候,通过对比目标到每个中心的距离,从而分析出目标的类别。
- support set 投影到新的特征空间,也就是nlp中的 word embedding,可以通过神经网络,把输入转化伪一个新的特征向量
- 基于 support set 计算每个类别的均值表示该类原型 prototype
- 用 SGD 等优化算法,优化距离损失,使得同一类的向量之间的距离比较接近,不同类的向量距离比较远
- 通过最小化损失函数,从而优化最初的 embedding
形式化描述
核心思想:在N-way K-shot设置下,通过计算 support set 中的嵌入中心,然后衡量新样本与这些中心的距离来完成分类。
- support set:包含少量标注的样本
- query set:包含未标注样本,和support set的样本空间一致,不能和 support set 有重复
- 计算嵌入中心公式,简单的取平均
- S k S_k Sk:类别为 k 的 support set
- f θ f_θ fθ:embedding 函数
-
x
i
x_i
xi:输入
c k = 1 ∣ S k ∣ ∑ i f θ ( x i ) c_k = \frac{1}{|S_k|} \sum_i f_θ(x_i) ck=∣Sk∣1i∑fθ(xi)
- 计算新样本 x 到每个类别 i 的嵌入中心的距离: d i = d ( f θ ( x ) , c i ) , i = 1 , . . . , N d_i = d(f_θ(x), c_i), i = 1,...,N di=d(fθ(x),ci),i=1,...,N,然后再用softmax对距离做映射,得到每个类别的概率 y ^ i = s o f t m a x ( d 1 , . . . , d k ) \hat y_i = softmax(d_1,...,d_k) y^i=softmax(d1,...,dk)。
- 训练目标通过SGD优化交叉熵损失函数:
- y y y:真实值
-
y
^
\hat y
y^:预测值
L ( y , y ^ ) = − ∑ i = 1 N ′ y i l o g y ^ i L(y, \hat y) = - \sum^{N'}_{i = 1}y_i log\hat y_i L(y,y^)=−i=1∑N′yilogy^i
参考资料:
Snell J, Swersky K, Zemel R S. Prototypical networks for few-shot learning[J]. NIPS 2017.
元学习系列(二):Prototypical Networks(原型网络)