论文阅读笔记《Adversarial Feature Hallucination Networks for Few-Shot Learning》

小样本学习&元学习经典论文整理||持续更新

核心思想

  本文提出一种基于数据增强的小样本学习算法(AFHN),利用生成对抗网络(GAN)实现数据集的扩充。数据增强的方法被认为可以增强类内样本方差的多样化,从而实现更加清晰地分类界限。先前的数据增强方法主要包含两类:一类是通过在基础数据集上学习一种变换映射,并将其直接应用到新的数据集上,得到映射后的合成图像用于数据扩充,这一类方法会破坏合成图像的区分能力(因为合成图像很粗糙,与原始类别并不相似);另一类方法是根据特定的任务生成对应的合成图像,这类方法保证了合成图像的区分能力,但特定的任务约束使得合成的图像容易陷入一种特定的模式,从而丧失了多样性(在GAN中这种情况称之为Mode Collapse,就是指生成的图像之间太过于相似,不具备多样性)。本文利用conditional Wasserstein Gener- ative Adversarial Networks ,cWGAN(与普通的GAN相比,cWGAN就是通过改进目标函数,进而提高训练稳定性的一个变种,此处不再详细介绍)生成样本,并通过增加分类正则项(classification regularizer)和 “反陷入”正则项(anti-collapse regularizer),解决了生成样本缺少区分能力和多样性的问题。本文提出算法的处理流程如下图
在这里插入图片描述
  首先支持集图像和查询集图像经过特征提取网络 F F F得到对应的特征向量,支持集对应的特征向量为 s s s(如果有多个样本则取平均值),从[0,1]的均匀分布中采样得到两个随机变量 z 1 , z 2 z_1,z_2 z1,z2。然后将特征向量 s s s z 1 , z 2 z_1,z_2 z1,z2输入到cWGAN的生成器 G G G中,得到合成的向量 s ~ 1 , s ~ 2 \tilde{s}_1,\tilde{s}_2 s~1,s~2,过程如下
在这里插入图片描述
将生成的 s ~ 1 , s ~ 2 \tilde{s}_1,\tilde{s}_2 s~1,s~2与原始的 s s s z 1 , z 2 z_1,z_2 z1,z2输入到区分器 D D D中,并计算GAN损失 L G A N {L}_{GAN} LGAN,过程如下
在这里插入图片描述
  而单纯的GAN损失并不能解决生成样本缺少区分能力和多样性的问题,因此本文又设计了两个正则化项:分类正则项(classification regularizer)和 “反陷入”正则项(anti-collapse regularizer)。其中分类正则项很好理解,首先利用softmax函数根据生成的样本 s ~ \tilde{s} s~得到查询样本 x q x_q xq对应类别的概率,计算过程如下
在这里插入图片描述
式中 q = F ( x q ) q=F(x_q) q=F(xq),然后再利用交叉熵损失函数计算分类损失,作为分类正则项 L c r i L_{cr_i} Lcri,该正则项的目的是为了增强生成样本的区分能力
在这里插入图片描述
而“反陷入”正则项则是直接对两个合成特征向量的不相似度和产生它们的两个噪声向量的不相似度的比值进行惩罚,文字表述比较复杂,我们直接看公式
在这里插入图片描述
式中,分子部分表示了两个合成特征向量之间的不相似度,而分母表示两个噪声向量之间的不相似度。有研究表明 z 1 z_1 z1 z 2 z_2 z2越相似,则 s ~ 1 \tilde{s}_1 s~1 s ~ 2 \tilde{s}_2 s~2越容易陷入同一种模式。当 z 1 z_1 z1 z 2 z_2 z2很相似时,也就是分母很小时,上式则相当于放大了 s ~ 1 \tilde{s}_1 s~1 s ~ 2 \tilde{s}_2 s~2之间的不相似度(因为要除以一个远小于1的数字)。该正则项的目的时为了增强生成样本的多样性。
  最后,将生成的样本 s ~ \tilde{s} s~与原始样本 s s s一起输入到分类器 C C C中,进而实现对于查询样本 x q x_q xq的分类。

实现过程

网络结构

  特征提取网络采用ResNet网络,生成器和区分器均采用带有Leaky ReLU激活函数的两层MLP网络。

损失函数

  对于生成对抗网络部分损失函数如下
在这里插入图片描述
值得注意的是“反陷入”正则项 L a r L_{ar} Lar取了倒数,因此对于生成器而言是希望生成的 s ~ 1 \tilde{s}_1 s~1 s ~ 2 \tilde{s}_2 s~2之间的不相似度越大越好。
  对于分类器部分采用简单的分类损失函数进行训练
在这里插入图片描述

训练策略

  本文的训练过程如下
在这里插入图片描述

创新点

  • 本文利用cWGAN网络生成样本,用于数据集扩充,改善小样本分类效果
  • 设计了两个正则化项,提高了生成样本的区分能力和多样性

算法评价

  本文还是比较标准的采用GAN生成样本,进而实现数据增强的算法。这一类方法通常因为样本太少,导致生成的样本效果太差,而无法起到数据增强的效果。而本文通过采用稳定性更好的cWGAN算法,并设计两个正则化项,改善了生成样本的效果,使其能够应用于小样本学习算法。

如果大家对于深度学习与计算机视觉领域感兴趣,希望获得更多的知识分享与最新的论文解读,欢迎关注我的个人公众号“深视”。在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

深视

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值