原型网络 Prototypical Network

原型网络 - Prototypical Network

原型网络出自下面这篇论文。
Snell J, Swersky K, Zemel R S. Prototypical networks for few-shot learning[J]. NIPS 2017.

请添加图片描述

原理

原理和聚类有点相似

孪生网络的缺点就是要对比目标和过去每个样本之间的相似度,从而分析目标的类别,而原型网络就提出,把样本投影到一个空间,计算每个样本类别的中心,在分类的时候,通过对比目标到每个中心的距离,从而分析出目标的类别。

  • support set 投影到新的特征空间,也就是nlp中的 word embedding,可以通过神经网络,把输入转化伪一个新的特征向量
  • 基于 support set 计算每个类别的均值表示该类原型 prototype
  • 用 SGD 等优化算法,优化距离损失,使得同一类的向量之间的距离比较接近,不同类的向量距离比较远
  • 通过最小化损失函数,从而优化最初的 embedding
形式化描述

核心思想:在N-way K-shot设置下,通过计算 support set 中的嵌入中心,然后衡量新样本与这些中心的距离来完成分类。

  • support set:包含少量标注的样本
  • query set:包含未标注样本,和support set的样本空间一致,不能和 support set 有重复
  • 计算嵌入中心公式,简单的取平均
    • S k S_k Sk:类别为 k 的 support set
    • f θ f_θ fθ:embedding 函数
    • x i x_i xi:输入
      c k = 1 ∣ S k ∣ ∑ i f θ ( x i ) c_k = \frac{1}{|S_k|} \sum_i f_θ(x_i) ck=Sk1ifθ(xi)
  • 计算新样本 x 到每个类别 i 的嵌入中心的距离: d i = d ( f θ ( x ) , c i ) , i = 1 , . . . , N d_i = d(f_θ(x), c_i), i = 1,...,N di=d(fθ(x),ci),i=1,...,N,然后再用softmax对距离做映射,得到每个类别的概率 y ^ i = s o f t m a x ( d 1 , . . . , d k ) \hat y_i = softmax(d_1,...,d_k) y^i=softmax(d1,...,dk)
  • 训练目标通过SGD优化交叉熵损失函数:
    • y y y:真实值
    • y ^ \hat y y^:预测值
      L ( y , y ^ ) = − ∑ i = 1 N ′ y i l o g y ^ i L(y, \hat y) = - \sum^{N'}_{i = 1}y_i log\hat y_i L(y,y^)=i=1Nyilogy^i

参考资料:
Snell J, Swersky K, Zemel R S. Prototypical networks for few-shot learning[J]. NIPS 2017.
元学习系列(二):Prototypical Networks(原型网络)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值