Learning to Propagate Labels:Transductive Propagation Network for Few-shot Learning 阅读笔记+代码

目录

一、Abstract

二、Introduction

 三、Contributions

四、Method

 1、问题定义

2、转导传播网络

2.1、特征嵌入

2.2、图构建

2.3、标签传播

2.4 损失 

五、实验


一、Abstract

小样本学习的目标是学习一个分类器,即使在每个类的训练实例数量有限的情况下也能很好地概括。最近引入的元学习方法解决了这个问题,它通过在大量的多类分类任务中学习一个通用分类器,并将模型推广到一个新任务中。然而,即使有了这种元学习,新的分类任务中的低数据问题仍然存在。

在本文中,我们提出了一种新的转导传播网络(TPN),它是一种新颖的元学习框架,可以对整个测试集进行一次性分类,以缓解低数据量的问题。

具体来说,我们建议通过学习利用数据中的流形结构的图构造模块来学习如何将标签从标记的实例传播到未标记的测试实例。TPN以端到端方式联合学习特征嵌入参数和图构造参数。我们在多个基准数据集上验证了TPN,它在很大程度上优于现有的小样本学习方法,并实现了最先进的结果。




二、Introduction

少量学习的目的是学习一个分类器,它可以很好地概括这些类中的每个类的几个例子。传统的技术,如微调,与深度学习模型工作良好,但在这个任务中会严重过拟合,因为单个或只有几个标签实例不能准确地代表真实的数据分布,将导致学习的分类器具有高方差,这将不能很好地推广到新数据。

为了解决这一过拟合问题,Vinyalset al.(2016)提出了一种元学习策略,即在大量的eposides中对不同的分类任务进行学习,而不是只对目标分类任务进行学习。在每一eposide中,算法学习少量标注示例(支持集)的嵌入情况,通过嵌入空间中的距离来预测未标注点(查询集)的类。情景训练的目的是模拟包含少量支持集和无标签查询集的真实测试环境。

训练和测试环境的一致性缓解了分布差距,提高了泛化能力。

这种情景元学习策略,由于其泛化性能,已被许多后续工作的小样本学习调整。Finnet al.(2017)学习了一个很好的初始化,可以快速适应目标任务。Snellet al.(2017)使用集来训练一个好的表示,并通过计算关于类原型的欧氏距离来预测类。

虽然情景策略是一种有效的学习方法,因为它的目标是将可见的分类任务推广到不可见的分类任务,但对于一个新的分类任务来说,在缺乏数据的情况下学习仍然存在根本的困难。用有限的训练数据实现更大改进的一种方法是考虑测试集中实例之间的关系,并以此作为一个整体来预测它们,这被称为转导,或转导推理。

在之前的工作中,特别是在小的训练集中,转导推理已经显示出优于预测一个接一个的测试例子的归纳方法。一种常用的转导方法是在有标记的和无标记的数据上构造一个网络,并在它们之间传播标记以进行联合预测。然而,这种标签传播(和转导)的主要挑战是,标签传播网络常常在没有考虑主要任务的情况下获得,因为在测试时不可能学习它们

通过情景训练的元学习,我们可以学习标签传播网络,因为从训练集中抽样的查询示例可以用来模拟真实的测试集进行转导推理。基于这一发现,我们提出了转导传播网络(TPN)来解决低数据量问题。我们没有应用归纳推理,而是利用整个查询集进行转导推理(参见图1)。

具体来说,我们首先使用深度神经网络将输入映射到嵌入空间。然后,利用支持集和查询集的并集,提出了一个图构造模块来利用新类空间的流形结构。根据图的结构,使用迭代标签传播将标签从支持集传播到查询集,最终得到封闭解。利用查询集的传播分数和ground truth标签,计算关于特征嵌入和图构造参数的交叉熵损失。最后,所有参数都可以使用反向传播进行端到端更新。


 三、Contributions

  1. 据我们所知,我们是第一个在小样本学习中明确地模拟转导推理的人。虽然Nicholet al.(2018)实验了一个传导设置,但他们只是通过批次标准化在测试示例之间共享信息,而不是直接提出一个传导模型。
  2. 在转换推理中,我们提出通过情景元学习去学习在未见类的数据实例之间传播标签。这种学习的标签传播图被证明显著优于朴素的启发式的标签传播方法。
  3. 我们在两个基准数据集(miniImageNet和tieredimagenet)上评估了我们的方法。实验结果表明,我们的转导传播网络优于最先进的方法在两个数据集。此外,在半监督学习下,我们的算法实现了更高的性能,优于所有的半监督小样本学习基线。

四、Method


 1、问题定义

由eposide训练实现的元学习在少量分类任务中表现良好。然而,由于支持集中缺少标记实例(K通常很小),我们观察到仍然很难获得一个可靠的分类器。这促使我们考虑利用整个查询集进行预测的转换设置,而不是独立地预测每个示例。考虑到整个查询集,我们可以缓解低数据问题,并提供更可靠的泛化属性。


2、转导传播网络

我们引入了如图2所示的转导传播网络(TPN),它由四个部分组成:特征嵌入卷积神经网络;生成实例参数的图构造,以开发流形结构;标签传播,将标签从支持集S扩散到查询集Q;一个损耗生成步骤,计算传播标签和Q上的ground-truth之间的交叉熵损耗,从而联合训练框架中的所有参数。


2.1、特征嵌入

 代码实现如下:

四层以下网络

 


2.2、图构建

为了在元学习中获得合适的邻域图,我们提出了一个基于支持集和查询集的并集的图构造模块:S∪Q。该模块由卷积神经网络g(φ)组成,g(φ)取xi∈S∪Q的特征映射fϕ(xi),生成一个样例长度尺度参数σi=gφ(fϕ(xi))。值得注意的是,尺度参数是通过实例明智地确定的,并且是通过情景性训练程序学习的,这可以很好地适应不同的任务,并适合于少量的学习。通过例子σi,我们的相似函数定义如下:

inp=torch.cat((support,query),0)
emb_all=self.encoder(inp).view(-1,1600) # 特征提取 
N,d=emb_all.shape[0],emb_all.shape[1]

#Step2:GraphConstruction
##sigmma
if self.args['rn']in[30,300]:
    self.sigma=self.relation(emb_all,self.args['rn']) # Graph construction structure

##W
emb_all=emb_all/(self.sigma+eps)#N*d
emb1=torch.unsqueeze(emb_all,1)#N*1*d
emb2=torch.unsqueeze(emb_all,0)#1*N*d
W=((emb1-emb2)**2).mean(2)#N*N*d->N*N
W=torch.exp(-W/2)#构造的图的边相似性

提出的图构造模块的结构如图3所示。它由两个卷积块和两个完全连接的层组成,每个块包含一个3 × 3卷积、批处理归一化、ReLU激活,然后是2 × 2 max pooling。每个卷积块中的过滤器数量分别为64个和1个。为了提供一个示例式的缩放参数,第二个卷积块的激活映射通过两个完全连接的层(其中神经元数量分别为8和1)转换为一个标量。

class RelationNetwork(nn.Module):
"""GraphConstructionModule"""
def__init__(self):
	super(RelationNetwork,self).__init__()
	self.layer1=nn.Sequential(
		nn.Conv2d(64,64,kernel_size=3,padding=1),
		nn.BatchNorm2d(64),
		nn.ReLU(),
		nn.MaxPool2d(kernel_size=2,padding=1))
		
	self.layer2=nn.Sequential(
		nn.Conv2d(64,1,kernel_size=3,padding=1),
		nn.BatchNorm2d(1),
		nn.ReLU(),
		nn.MaxPool2d(kernel_size=2,padding=1))
		
	self.fc3=nn.Linear(2*2,8)
	self.fc4=nn.Linear(8,1)
	
	self.m0=nn.MaxPool2d(2)#max-poolwithoutpadding
	self.m1=nn.MaxPool2d(2,padding=1)#max-poolwithpadding

def forward(self,x,rn):
	x=x.view(-1,64,5,5)
	out=self.layer1(x)
	out=self.layer2(out)
	
	#flatten
	out=out.view(out.size(0),-1)
	out=F.relu(self.fc3(out))
	out=self.fc4(out)#norelu
	out=out.view(out.size(0),-1)#bs*1
	
	return out

我们遵循情景范式进行少量的元学习者训练。这意味着该图是为每个episode中的每个任务单独构建的,如图1所示。通常,在5-way 5-shot训练中,N= 5, K= 5, T= 75, W的维数仅为100×100,这是非常有效的。


2.3、标签传播

 我们现在描述如何在最后一个交叉熵损失步骤之前,使用标签传播获得查询集Q的预测。设F表示具有非负项的(N×K+T)×N矩阵的集合。

具体细节:以 5-ways 5-shots为例,查询集为,每类15个样本

1、输入数据:

     support:[25,3,84,84]

    s_onehot:[25,5]

    query:[75,3,84,84]

    q_onehot:[75,5]

2、进行特征提取,以及计算图的边的相似性W

3、取W每行最大的top-k,将其保留下来,其余为0

4、对W进行正则化

D=W.sum(0)

D_sqrt_inv=torch.sqrt(1.0/(D+eps))

D1=torch.unsqueeze(D_sqrt_inv,1).repeat(1,N)

D2=torch.unsqueeze(D_sqrt_inv,0).repeat(N,1)

S=D1*W*D2

5、Y为支持集标签的one-hot,和查询集的one-hot(全为0,需要进行预测的)

ys=s_labels

yu=torch.zeros(num_classes*num_queries,num_classes)

#yu=(torch.ones(num_classes*num_queries,num_classes)/num_classes).cuda(0)

y=torch.cat((ys,yu),0)

6

torch.matmul(torch.inverse(torch.eye(N)-self.alpha*S+eps),y)

2.4 损失 

这一步的目的是计算通过标签传播对支持和查询集并集的预测与ground-truths之间的分类损失。我们计算预测分数F∗和S∪Q中ground truth标签之间的交叉熵损失,以端到端方式学习所有参数,其中F∗使用softmax转换为概率分数:

五、实验

  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值