Few-Shot Segmentation Without Meta-Learning: A Good Transductive Inference Is All You Need?

CVPR_21

Abstract

  • 在小样本分割任务中执行推理 i n f e r e n c e inference inference的方式对性能有实质性的影响——这是支持元学习范式的文献中经常忽视的一个方面
  • 提出 t r a n s d u c t i v e   i n f e r e n c e transductive \ inference transductive inference: 对给定的 q u e r y   i m a g e query \ image query image, 统计其未打标的像素的信息,优化包含三个互补相的新损失
  • 使用了提取特征的简单线性分类器,其计算负荷与归纳推理相当,并且可以在任何基础训练之上使用

1 Introduction

为了实现小样本学习,元学习得到了很大关注,现状是:

  • 最近的几项图像分类工作观察到元学习可能在标准的1或5分类基准之外具有有限的泛化能力——比基线方法表现不好
  • 大量的研究工作集中在基础训练的专门架构和情景训练方案的设计上
  • 现有的元学习方法在跨领域(cross-domain)场景中竞争力较弱

元学习内涵的可能假设:

  • 情景(episodic) 训练方式,本身隐含的假设了 t e s t i n g   t a s k testing \ task testing task的结构与 m e t a − t r a i n i n g meta-training metatraining阶段的 t r a i n i n g   t a s k training \ task training task结构相似
  • base类和novel类往往都是从同一个数据集中采样得到

这些假设可能会限制现有的少镜头分割方法在现实场景中的适用性,同时上述的现状,让作者怀疑元学习在小样本学习上的相关性

贡献:

  • 放弃元学习,重新考虑一个监督基类在训练期间进行特征提取的简单 c r o s s − e n t r o p y cross-entropy crossentropy——损失函数
  • 提出了一种比现有方法更好地利用支持集监督的 t r a n s d u c t i v e   i n f e r e n c e transductive \ inference transductive inference

2 Related Work

2.1 Few-Shot Learning for classification

元学习已经成为从小样本中学习新任务的事实上的解决方案,这些工作可以分为基于梯度(gradient)学习和基于度量(metric)学习的两类方法

  • 梯度方法采用随机梯度下降(stochastic gradient descent,SGD)来学习不同任务之间的共性
  • 度量学习方法采用深度网络作为特征嵌入(feature-embedding)函数,并比较嵌入之间的距离

在最近的一系列工作中,已研究了用于小样本分类的 t r a n s d u c t i v e   s e t t i n g transductive \ setting transductive setting并获得了优于归纳推理的性能改进。

  • 采用推到推理的工作大多遵循半监督学习策略
  • 熵(entropy)作为传导损失的一部分,但对分割任务是不充足的【熵最小化,entropy minimization】

2.2 Few-shot segmentation

分割可以被视为像素级的分类,最近的工作主要集中在专门架构的设计上

现有的方法灵感大多来自于原型网络

  • 早期架构是使用双分支比较(two-branch comparison)框架:一个从 s u p p o r t   i m a g e support \ image support image使用嵌入函数生成原型,另一个用学习到的原型分割查询图像——也会先使用嵌入函数
  • 最近,双分支设置被统一为单个分支,对支持集和查询集使用相同的嵌入函数

注: 这些方法的目的在于学习更好的特定类的表示(class-specific representation)或 迭代改进已学习的表示

3 Formulation

3.1 Few-shot Setting

训练阶段: 使用基类数据集: D b a s e \mathcal{D}_{base} Dbase,包含基语义类(with base semantic classes) Y b a s e \mathcal{Y}_{base} Ybase

D = ( x n , y n ) n = 1 N \mathcal{D} = {(x_n,y_n)}_{n=1}^N D=(xn,yn)n=1N, Ω ⊂ R 2 \varOmega \subset \mathbb{R}^2 ΩR2是一个图像空间, x n : Ω → R 3 x_n: \varOmega \to \mathbb{R}^3 xn:ΩR3是一个输入图像, y n : Ω → { 0 , 1 } ∣ Y b a s e ∣ y_n:\varOmega \to \{ 0,1\}^ {|\mathcal{Y}_{base}|} ynΩ{0,1}Ybase 是该图像的像素级独热编码(pixelwise one-hot)

测试阶段: 使用一些列 K − s h o t s K-shots Kshots任务,每个任务包含一个完全打标的support set S = ( x k , y k ) k = 1 K S = {(x_k,y_k)}_{k=1}^K S=(xk,yk)k=1K 和一个没有标签的 query image x Q x_Q xQ

  • 没有标签的图像都来自同一个新类(novel class),且该新类是从一组新类 Y n o v e l \mathcal{Y}_{novel} Ynovel中随机抽样得到
  • Y b a s e ∩ Y n o v e l = ∅ \mathcal{Y}_{base} \cap \mathcal{Y}_{novel} = \empty YbaseYnovel=——目标是利用支持集提供的监督,以便在查询图像中正确地分割感兴趣的对象

3.2 Base training

3.2.1 Inductive bias in episodic training

存有多种利用 D b a s e \mathcal{D}_{base} Dbase的方法,普遍使用的 元学习通过将 D b a s e \mathcal{D}_{base} Dbase构造为一系列训练任务来模拟训练期间的测试时场景,然后在这些任务上训练模型,以学习如何最好地利用支持集的监督

A theoreti-cal analysis of the number of shots in few-shot learning 证明:以原型网络为案例,训列阶段中的 K t r a i n K_{train} Ktrain(样本个数)表示一个 l e a r n i n g   b i a s learning \ bias learning bias,且当 K t r a i n ! = K t e s t K_{train} != {K_test} Ktrain!=Ktest时,测试的性能会迅速饱和

3.2.2 Standard training

测试任务应该是未知的,希望在这个问题上少做假设,所以在 D b a s e \mathcal{D}_{base} Dbase使用一个经标准交叉熵(cross-entropy)监督的特征提取器 f ∅ f_{\empty} f,而不是诉诸情景训练——【不要像元学习那样在训练时模拟测试

3.3 Inference

In what follows, we use . . . as a placeholder to denote either a support subscript k ∈ {1, …, K} or the query subscript Q.——下标点表示支持集或查询集

3.3.1 Objective

s u p p o r t support support q u e r y query query图像都会提取出特征 z . : = f ∅ ( x . ) z. := f_\empty(x.) z.:=f(x.) , z . : Ψ → R C z. : \varPsi \to \mathbb{R}^C z.:ΨRC,其中, C C C是特征空间 Ψ \varPsi Ψ里的通道维度,且| Ψ \varPsi Ψ| < | Ω \varOmega Ω|

目标: 使用特征 z . z. z.来学习分类器 p . p. p. 的参数 θ \theta θ以正确区分前景和背景像素,其中, p . : Ψ → [ 0 , 1 ] 2 p. : \varPsi \to [0,1]^2 p.:Ψ[0,1]2((二分类))会为提取的特征空间中的每个像素分配一个(B/F,background/foreground)概率向量


对于每个测试任务,通过优化一下传导目标来找到分类器的参数 θ \theta θ,其中 λ \lambda λ是平衡两个损失的非负超参:
min ⁡ θ C E + λ H H + λ K L D K L \min \limits_{\theta} CE+\lambda_{\mathcal{H}}{\mathcal{H}} + \lambda_{KL}\mathcal{D}_{KL} θminCE+λHH+λKLDKL

  • CE 是 来自支持图像的下采样后的标签 y k ~ \tilde{y_k} yk~ 与 分类器的软预测(soft prediction)之间的交叉熵;仅最小化这一项会导致过拟合
    在这里插入图片描述
    注: 软标签是概率值,(0,1)区间上的值;而硬标签是0或1

  • H \mathcal{H} H 是查询图像像素的预测值的香农熵(Shannon entropy);这一项是为了让模型在预测时更加自信(上一项是soft预测)。直观上,它将线性分类器的决策边界推向了query上提取的特征空间的低密度区(让预测值向0和1靠拢,变成硬标签?),但尽管这一项的作用很重要,可如果仅直接将它加在CE loss上并没用
    在这里插入图片描述

  • D K L \mathcal{D}_{KL} DKL是KL散度,其中 P Q ^ = 1 ∣ Ψ ∣ T ∑ j ∈ Ψ P Q ( j ) \widehat{P_Q}={1 \over |\varPsi|}^T \sum_{j \in \varPsi}P_Q(j) PQ =Ψ1TjΨPQ(j) , 鼓励模型预测的B/F比例匹配参数 π ∈ [ 0 , 1 ] 2 π∈[0,1]^2 π[0,1]2;当预测出的b/f比例与π不同时,该项可以防止模型陷入前两个loss中的过拟合
    在这里插入图片描述

3.3.2 Choice of the classifier

目标: 当我们在推理时为每个任务优化θ时,我们希望我们的方法添加尽可能少的计算负荷

想法: 对支持集和查询集使用一个相同的简单线性分类器,该分类器具有可学习参数 θ ( t ) = { w ( t ) , b ( t ) } \theta^{(t)}= \{ w^{(t)},b^{(t)} \} θ(t)={w(t),b(t)} t t t是优化过程的当前步骤, w ( t ) ∈ R C w^{(t)} \in \mathbb{R}^C w(t)RC f o r e g r o u n d   p r o t o t y p e foreground \ prototype foreground prototype, b ( t ) ∈ R b^{(t)}\in \mathbb{R} b(t)R是相应的偏置

在这里插入图片描述
其中, s . ( t ) ( j ) = s i g m o i d ( τ [ c o s ( z . ( j ) , w ( t ) ) − b ( t ) ] ) s.^{(t)}(j)=sigmoid(\tau[cos(z.(j),w^{(t)})−b^{(t)}]) s.(t)(j)=sigmoid(τ[cos(z.(j),w(t))b(t)]) τ ∈ R \tau \in \mathbb{R} τR是温度超参

  • w w w:第一次迭代的初始化设置为support中所有图片中feature map上前景点像素的占比
  • b b b:第一次迭代的初始化设置为query上的soft pred的均值
  • 随后以梯度下降进行优化

在这里插入图片描述

3.3.3 Joint estimation of B/F proportion π π π

  • 无额外信息,利用模型在query image上标签边缘分布,从而在训练分类器时共同学习 π π π
  • 模型的inference可以看作是 θ θ θ π π π的共同训练
  • KL散度项作为自正则化,可以防止模型的边缘分布产生偏差


  • 最小化前面的第一个公式(两个熵+一个散度),可以发现 π = p Q ^ π = \widehat{p_Q} π=pQ , p Q ^ \widehat{p_Q} pQ 是query图像的特征图上的各个像素分类预测值在整个图上的平均值【看前面的公式】—— 意味着更新分类器参数时用到了query图像的数据信息,体现了 t r a n s d u c t i v e transductive transductive思想——支持集和查询集公用一个分类器,分类器的更新用到了查询集样本信息,即在训练时用到了查询集
    在这里插入图片描述

注: inductive,transductive

  • 作者经验发现, π π π更新的也不需要很频繁,一般初始化为第一次迭代的得到的 p Q ^ \widehat{p_Q} pQ ,然后在中间再找一次迭代出来的 p Q ^ \widehat{p_Q} pQ 更新一次 π π π。只更新一次即可

3.3.4 Oracle case with a known π π π

π π π估计的不可能很准,作为方法的上限,可以用统计出来的前景点比例来设置这个pi,并不让它更新
在这里插入图片描述

O r a c l e   r e s u l t s Oracle \ results Oracle results:

  • 证明了存在一个简单的线性分类器,它可以在很大程度上优于最先进的元学习模型,而且该分类器建立在使用标准交叉熵损失训练的特征提取器之上
  • 表明了如果可以拥有目标的某些精确信息(如前背景之比),那这一项就可以用来当作一个强大的正则化项(strong regularizer)
  • 这表明,在适当约束w和b的优化过程方面可以做出更多的努力,并为有前途的途径打开了一扇门

4 Experiments

4.2 Domain shift

引入了一个更现实的跨域设置:训练在一个数据集,测试在另一个数据集

  • 这样的设置是迈向对这些方法进行更现实的评估的一步,因为它可以评估数据训练分布和测试分布之间的域转移对性能的影响
  • 我们相信这种情况在实践中很容易发现,因为即使是数据收集过程中的微小变化也可能导致分布转移
  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

明前大奏

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值