Dynamic Few-Shot Visual Learning without Forgetting: Reading Notes

Research Problem

Learning new concepts from few samples: the paper designs a few-shot visual learning system that can efficiently learn novel categories from a handful of training samples at test time, while at the same time not forgetting the original base categories.

Contributions

  1. An attention-based weight generator that produces classification weight vectors for novel few-shot categories.

  2. The dot-product operation in the classifier is replaced with a cosine-similarity operation; concretely, the classifier computes the cosine similarity between the feature representation and the classification weight vectors.

    Why use cosine similarity?

    With the plain dot-product operation, the scores computed for novel classes fluctuate widely in magnitude and differ markedly from those of the base classes. This causes two problems (a minimal cosine-classifier sketch follows this list):

    1. It hampers training.
    2. Such a pronounced gap between base and novel classes does not match reality.
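
To make this concrete, here is a minimal sketch of a cosine-similarity classifier. It is not the authors' code: the name CosineClassifier and the learnable scale scale_cls are my own choices, but the pattern, L2-normalizing both features and weights and then applying a learned temperature, matches what the paper describes.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CosineClassifier(nn.Module):
    def __init__(self, num_features, num_classes, scale_cls=10.0):
        super(CosineClassifier, self).__init__()
        # One classification weight vector per class
        self.weight = nn.Parameter(torch.randn(num_classes, num_features) * 0.01)
        # Learnable temperature; raw cosine scores in [-1, 1] are too flat for softmax
        self.scale_cls = nn.Parameter(torch.tensor(scale_cls))

    def forward(self, features):
        # Both operands are L2-normalized, so the dot product below is a
        # cosine similarity; the learnable scale restores dynamic range.
        features = F.normalize(features, p=2, dim=-1)
        weight = F.normalize(self.weight, p=2, dim=-1)
        return self.scale_cls * features.matmul(weight.t())

Since every score now lies in a bounded range regardless of whether a weight vector was learned on a base class or generated for a novel one, base and novel scores become directly comparable.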

Problem Setup

  1. The base categories come with abundant training data.

  2. Without forgetting the base categories or retraining, the system must recognize not only the base categories but also novel ones, whose samples are provided only at test time.

Method

Training proceeds in two stages:

  1. Stage 1: ordinary supervised training on the base-class data, yielding the parameters of the feature extractor and of the classifier.

    Open question: does the classifier in this stage use the dot product or the cosine similarity?

  2. Stage 2: freeze the feature extractor's parameters and learn the weight generator's parameters (a toy sketch of this stage follows below).
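
The following toy loop illustrates the shape of stage 2, with dummy stand-ins everywhere: the linear "extractor", the linear "generator", and the random data are placeholders of my own, not the paper's models. What it shows is the recipe: the extractor is frozen, "fake novel" episodes are sampled, and only the generator's parameters phi receive gradients from the classification loss.

import torch
import torch.nn as nn
import torch.nn.functional as F

nFeat, nKbase, nKnovel, nShot, nQuery = 64, 10, 5, 1, 20

extractor = nn.Linear(32, nFeat)        # dummy feature extractor
for p in extractor.parameters():
    p.requires_grad = False             # stage 2: extractor stays frozen

weight_base = F.normalize(torch.randn(nKbase, nFeat), dim=-1)  # from stage 1
generator = nn.Linear(nFeat, nFeat)     # dummy stand-in for G(. | phi)
opt = torch.optim.SGD(generator.parameters(), lr=0.1)

for step in range(100):
    support = torch.randn(nKnovel * nShot, 32)        # fake "novel" support set
    query = torch.randn(nQuery, 32)
    query_y = torch.randint(nKbase + nKnovel, (nQuery,))
    z_support = extractor(support).view(nKnovel, nShot, nFeat).mean(dim=1)
    w_novel = F.normalize(generator(z_support), dim=-1)
    w_all = torch.cat([weight_base, w_novel], dim=0)  # base + generated weights
    logits = 10.0 * F.normalize(extractor(query), dim=-1) @ w_all.t()
    loss = F.cross_entropy(logits, query_y)
    opt.zero_grad()
    loss.backward()
    opt.step()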

The paper's biggest contribution: how the classification weight generator is implemented

  1. Extract feature vectors from the novel class's few samples and feed them, together with the existing base-class weights, into the weight generator to obtain the novel class's classification weight: $w'_n = G(Z'_n, W_{base} \mid \phi)$, where $\phi$ are the learnable parameters.

  2. How is $\phi$ learned?

    • Take the average of the novel samples' feature vectors: $w'_{avg} = \frac{1}{N'} \sum_{i=1}^{N'} \bar{z}'_i$

    • Use learnable parameters to turn each feature vector into a query vector, where the $k_b$ are learnable keys:
      $w'_{att} = \frac{1}{N'} \sum_{i=1}^{N'} \sum_{b=1}^{K_{base}} \mathrm{Att}(\phi_q \bar{z}'_i, k_b) \cdot \bar{w}_b$
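
In the paper, the final novel-class weight combines both branches through learnable coefficient vectors, $w' = \phi_{avg} \odot w'_{avg} + \phi_{att} \odot w'_{att}$ (with $\odot$ the element-wise product); the code below implements the attention branch.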

## talk is cheap, show me the code
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionBasedBlock(nn.Module):
    def __init__(self, nFeat, nK, scale_att=10.0):
        super(AttentionBasedBlock, self).__init__()
        self.nFeat = nFeat
        # Query layer phi_q, initialized close to the identity matrix
        self.queryLayer = nn.Linear(nFeat, nFeat)
        self.queryLayer.weight.data.copy_(
            torch.eye(nFeat, nFeat) + torch.randn(nFeat, nFeat) * 0.001)
        self.queryLayer.bias.data.zero_()

        # Learnable softmax temperature for the attention scores
        self.scale_att = nn.Parameter(
            torch.FloatTensor(1).fill_(scale_att), requires_grad=True)
        # Learnable keys k_b, one per base category
        wkeys = torch.FloatTensor(nK, nFeat).normal_(0.0, np.sqrt(2.0 / nFeat))
        self.wkeys = nn.Parameter(wkeys, requires_grad=True)

    def forward(self, features_train, labels_train, weight_base, Kbase):
        batch_size, num_train_examples, num_features = features_train.size()
        nKbase = weight_base.size(1) # weight_base: [batch_size x nKbase x num_features]
        labels_train_transposed = labels_train.transpose(1, 2)
        nKnovel = labels_train_transposed.size(1) # [batch_size x nKnovel x num_train_examples]

        features_train = features_train.view(
            batch_size * num_train_examples, num_features)
        Qe = self.queryLayer(features_train)
        Qe = Qe.view(batch_size, num_train_examples, self.nFeat)
        Qe = F.normalize(Qe, p=2, dim=Qe.dim() - 1, eps=1e-12)

        wkeys = self.wkeys[Kbase.view(-1)] # the keys of the base categories
        wkeys = F.normalize(wkeys, p=2, dim=wkeys.dim() - 1, eps=1e-12)
        # Transpose from [batch_size x nKbase x nFeat] to
        # [batch_size x self.nFeat x nKbase]
        wkeys = wkeys.view(batch_size, nKbase, self.nFeat).transpose(1, 2)

        # Compute the attention coefficients
        # batch matrix multiplication: AttentionCoefficients = Qe * wkeys ==>
        # [batch_size x num_train_examples x nKbase] =
        #   [batch_size x num_train_examples x nFeat] * [batch_size x nFeat x nKbase]
        AttentionCoefficients = self.scale_att * torch.bmm(Qe, wkeys)
        AttentionCoefficients = F.softmax(
            AttentionCoefficients.view(batch_size * num_train_examples, nKbase),
            dim=1)
        AttentionCoefficients = AttentionCoefficients.view(
            batch_size, num_train_examples, nKbase)

        # batch matrix multiplication: weight_novel = AttentionCoefficients * weight_base ==>
        # [batch_size x num_train_examples x num_features] =
        #   [batch_size x num_train_examples x nKbase] * [batch_size x nKbase x num_features]
        weight_novel = torch.bmm(AttentionCoefficients, weight_base)
        # batch matrix multiplication: weight_novel = labels_train_transposed * weight_novel ==>
        # [batch_size x nKnovel x num_features] =
        #   [batch_size x nKnovel x num_train_examples] * [batch_size x num_train_examples x num_features]
        weight_novel = torch.bmm(labels_train_transposed, weight_novel)
        # Average the per-example weights within each novel class
        weight_novel = weight_novel.div(
            labels_train_transposed.sum(dim=2, keepdim=True).expand_as(weight_novel))

        return weight_novel
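
As a quick sanity check of the block's shapes, here is a hypothetical usage sketch (all sizes are made up; in the real code the one-hot support labels and the base-weight indices come from the episode sampler):

import torch
import torch.nn.functional as F

B, nKnovel, nShot, nFeat, nKbase = 2, 5, 1, 64, 10
gen = AttentionBasedBlock(nFeat=nFeat, nK=nKbase)

features_train = torch.randn(B, nKnovel * nShot, nFeat)   # support features
labels = torch.arange(nKnovel).repeat_interleave(nShot)   # class id per support example
labels_train = F.one_hot(labels, nKnovel).float().unsqueeze(0).repeat(B, 1, 1)
weight_base = F.normalize(torch.randn(B, nKbase, nFeat), dim=-1)  # base weights
Kbase = torch.arange(nKbase).repeat(B, 1)                 # indices into the learnable keys

weight_novel = gen(features_train, labels_train, weight_base, Kbase)
print(weight_novel.shape)  # torch.Size([2, 5, 64])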

Datasets

  1. Mini-ImageNet

  2. The low-shot ImageNet benchmark of Hariharan and Girshick [B. Hariharan and R. Girshick. Low-shot visual recognition by shrinking and hallucinating features. arXiv preprint arXiv:1606.02819, 2016.]

Conclusions

Evaluated on Mini-ImageNet, the method reaches 56.2% accuracy in the 1-shot setting and 73.0% in the 5-shot setting, while sacrificing no accuracy on the base categories.
