【论文阅读】Prototypical Networks for Few-shot Learning

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

本文结合论文和youtube上的视频[Few-shot learning][2.2] Prototypical Networks: intuition, algorithm, pytorch code来整理一下对prototypical networks在few-shot领域的理解。


一、论文

摘要

问题:少镜头分类问题(在只给定少量实例的情况下,分类器必须推广到未看到的新类)。
提出的解决方案:Prototypical Networks学习一个度量空间,在该空间中,可以通过计算得到每个类的原型表示的距离来执行分类。
优点:与最近少镜头学习方法相比,它们反映了一种更简单的归纳偏差,这种偏差在优先数据的状态下是有益的,取得了出色的结果。
分析:我们表明一些简单的设计决策可以比最近涉及复杂架构选择和元学习的方法产生实质性的改进。
扩展:扩展到了0样本学习,在CU-Birds dataset中获得了最先进的结果。

方法

在这里插入图片描述
这是在度量空间中,左边是few-shot,是计算每个类的embedded支持例的平均值得到ck。右边是zero-shot,通过embedding类别元数据vk生成的。在每一种情况下,embedded查询点是通过softmax对类原型的距离进行分类。pφ(y = k|x) ∝ exp(−d(fφ(x), ck)).

ck中心就是每一个类通过embedding函数得到的支持点的平均值。就是embedding相当于一个有很多维的一个空间中的一个点(我觉得类似特征提取得到得特征,这些特征得到的相当于一个高维空间中得坐标,每个类的支持点坐标不一定相同但是相近,它们的平均值可以近似看作这个类在这个高维空间中聚类的那个中心点)。
就是属于哪个类的概率p的计算是通过softmax函数得到的。p(y=k|x)是到自己true类别的距离的相反数的exp()比到其他类别距离的相反数的exp的和。loss就是-log(p(y=k|x))。
下面是loss的计算。
在这里插入图片描述
距离:距离计算有很多公式,对于一类特定的距离函数,称为正则布雷格曼散度[4],原型网络算法等效于对具有指数族密度的支持集执行混合密度估计。
原型计算可以从支持集上的硬聚类来看,每个类一个聚类,每个支持点分配给其相应的类聚类。对于布雷格曼散度,已经表明[4],达到到其指定点的最小距离的聚类代表是聚类均值。因此,当使用布雷格曼散度时,公式(1)中的原型计算在给定支持集标签的情况下产生最优聚类代表。所以他才取的均值。
后面就是对指数组混合模型的一些数学公式,我暂时看不懂。

重新解释为线性模型
当我们使用欧几里得距离 d(z, z′) = ‖z − z′‖2 时,方程 (2) 中的模型等效于具有特定参数化的线性模型 [21]。若要查看此内容,请展开指数中的项:
− ‖ f φ ( x ) − c k ‖ 2 = − f φ ( x ) T f φ ( x ) + 2 c k T f φ ( x ) − c k T c k −‖fφ(x) − c_k‖2 = −fφ(x)^Tfφ(x) + 2c^T_k fφ(x) − c^T_k c_k fφ(x)ck‖2=fφ(x)Tfφ(x)+2ckTfφ(x)ckTck
等式中的第一项相对于类k是常数,所以他就变成线性的函数了。
2 c k T f φ ( x ) − c k T c k = w k T f φ ( x ) + b k , w h e r e w k = 2 c k a n d b k = − c k T c k 2c^T_k fφ(x) − c^T_k ck = w^T_k fφ(x) + b_k, where w_k = 2c_k and b_k = −c^T_k c_k 2ckTfφ(x)ckTck=wkTfφ(x)+bk,wherewk=2ckandbk=ckTck

与匹配网络比较:原型网络与匹配网络在少数镜头情况下不同,在单镜头场景中具有等效性。
设计选择:Distance metrics, Episode composition

二、视频

在这里插入图片描述
先讲了聚类算法是怎么进行的。
在这里插入图片描述
然后讲了prototype的运行方式。对一个3-way 5-shot任务来说,他有五个支持图片,每个支持图片进入到一个编码器生成zi,这些zi做平均mean得到ci。3 way一次有三类。查询图像经过相同的编码器得到za,计算与这三个zi的距离,经过softmax函数,得到属于每一个类的概率。
在这里插入图片描述
然后对loss的计算过程进行梳理。
在这里插入图片描述
在这里插入图片描述
这里是伪代码。我感觉主要有两个步骤一个根据支持点得到z_proto(详细一点就是让所有输入通过网络得到z,根据每一类的支持点的z取平均得到每一类的z_proto),第二步计算距离,得到loss(用距离函数计算z_query和z_proto的距离,使用softmax函数得到x_query属于每一类的概率,然后根据query的标签计算loss)。
在这里插入图片描述
列出了prototype Networks的优缺点。


总结

原型网络的简单性和有效性使其成为少镜头学习的有前途的方法。

  • 8
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值