基于飞桨复现 CVPR2018 Relation Net的全程解析

本文通过飞桨框架详细解析并复现了CVPR2018论文Relation Net,该模型解决深度学习中小样本学习问题,通过端到端训练提高分类效果。模型由embedding model和relation model组成,采用episode based training策略进行训练,最终在minImageNet数据集上达到预期准确率。
摘要由CSDN通过智能技术生成

 

【飞桨开发者说】佟兴宇,北京航空航天大学硕士,机器视觉算法工程师。

Relation Net 是 CVPR2018的一篇论文,论文链接:

https://arxiv.org/pdf/1711.06025.pdf

论文作者发现,在视觉识别任务中,训练模型时需要大量标注过的图片,并迭代多次去训练参数。每当新增物体类别,都需要花费大量时间去标注,但是有一些新兴物体类别和稀有物体类别可能不存在大量标注过的图片,从而影响模型训练效果。反观人类,只要很少的认知学习就可实现小样本(FSL)和无样本学习(ZSL)。

比如:小孩子只要在一张图片或一本书里认识了斑马,或者只是听到描述斑马是一种”条纹马”,就可以毫无困难的识别出斑马这种动物。为了解决深度学习中模型样本少导致的分类效果差的问题,同时又受到人类的小样本和无样本学习能力带来的启发,小样本学习又恢复了一些热度。

深度学习中的Fine-tune技术可以用于一些样本比较少的情况,但是在只有一个或者几个样本的情况下,即使使用了数据增强和正则化技术,仍然会有过拟合的问题。目前其他的小样本学习的推理机制又比较复杂,所以论文作者提出了一个可以端到端训练,并且结构简单的模型Relation Net。

在 FSL 任务中,一般将数据集分为 Training set 、Support set 、Testing set。Support set和 Testing set有共同的标签;Training set里面不包涵 Support set和 Testing set的标签;在 Support set 中有 K 个标注过的数据和C个不同的类别,则称作为 C-way K-shot。在训练的过程中从 Training set 中选取 sample set /query set 对应Support set / Testing set,具体方法在文中的训练策略里会详细说明。

Relation Network由 embedding model 和 relation model 组成。Relation Network 的核心思想是:首先通过embedding model分别提取 support set 和 testing set中图像的特征图,然后将特征图中代表通道数的维度进行拼接,得到一个新的特征图。然后把新的特征图送入 relation model 进行运算得到 relation score,这个值代表了两张图的相似度。

下图为5-way 1-shot 的情况下接受1个样本的网络结构与流程。5张sample set 中的图片与1张 query set 中的图片会分别的通过 embedding model 提取特征并拼接,得到5个新的特征图,然后送入 Relation Net 进行计算 relation score,最后会得到一个 one-shot 的向量,分数最高的代表对应的类别。

训练使用的损失函数也比较简单,使用均方误差作为损失函数。公式中 ri,j代表图片 i与 j 的相似度。yi 与 yj代表图片的真实标签。

基于飞桨复现

Relation Network

下面我将复现的技术细节与各位开发者分享,Relation Network 模型结构定义请查看:

https://github.com/txyugood/paddle_RN_FSL/blob/master/RelationNet.py

1. 搭建 Relation Network 网络

模型由embedding model 和 relation model 两部分组成,两个网络都主要由 【Conv+BN+Relu】 模块组成。因此先定义一个 BaseNet类,并在其中实现conv_bn_layer方法,代码如下:

class BaseNet:
    def conv_bn_layer(self,
                      input,
                      num_filters,
                      filter_size,
                      stride=1,
                      groups=1,
                      padding=0,
                      act=None,
                      name=None,
                      data_format='NCHW'):
        n = filter_size * filter_size * num_filters
   
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值