目录
原型网络(Prototypical Networks)
1. 主要思想
把样本空间投影(嵌入到一个低维空间),利用样本在低维空间的相似度做分类。类似k-means聚类算法,在低维空间中找到每个分类的聚类中心。 用距离函数测新的样本的分类。
2. 模型
样本: K个分类,每个类 N 个样本。
把N分成 NS 和 NQ (N=NS+NQ)。
对应的样本集合分别记为 Sk 支持集( support examples)和 Qk 查询集(query examples)。
低维映射: 神经网络函数fφ(x)把样本x映射到嵌入空间。
每个类的聚类中心(原型):
利用支持集的样本,在嵌入空间中得到第k个类的聚类中心,即支持集第k类样本的平均值。这个中心称为第k类的原型。
目标函数:
(1) 距离函数:给出在嵌入空间的距离函数。
(2) 刻画样本x属于哪个类:
在知道每类样本的聚类中心后,我们就可以刻画一个样本x属于哪个类,用距离函数和softmax函数表示。x属于第k个分类的概率如上图。
(3) 求网络fφ(x)的参数 φ 用到的目标函数 J:
已知样本x对应的第k类:
随机梯度下降法最小化目标函数:概率的-log值,得到最优参数 φ 。
3. 算法
由 Sk 支持集(support examples)的 NS 个样本来确定每个类的聚类中心。
用 Qk 查询集(query examples)的 NQ 个样本计算目标(损失)函数 J 。
这里没有用全部的 K 个类,而是只用了 NC 个类 ( NC ≤ K ) 。
注意:
得到损失函数 J 后,用随机梯度下降法更新嵌入函数的参数 φ。
后文用了regular Bregman divergences距离函数和混合密度分布,还用欧氏距离将原型网络解释为线性模型。
4. 少样本和零样本学习
零样本学习不同于少样本学习,其聚类中心 ck 不是由支持集样本生成的,而是直接给出了元数据(meta-data)样本向量 vk (可以由原始数据等生成),再由 vk 单独嵌入生成 ck 。详见下图:
5. 实验
N-way K-shot:N个类,每个类K个样本
(1) 数据集Omniglot上少样本分类