Dynamic Few-Shot Visual Learning without Forgetting
## Research Question
This paper aims to design a few-shot visual learning system: one that, at test time, can efficiently learn new concepts from a handful of training examples while not forgetting the original (base) categories.
## Contributions
- An attention-based weight generator that produces classification weights for few-shot (novel) categories.
- The dot-product operation in the classifier is replaced with cosine similarity; concretely, the classifier scores each class by the cosine similarity between the feature representation and that class's classification weight vector.
Why cosine similarity? With a plain dot-product operation, the magnitudes of the weights computed for novel classes vary widely and differ markedly from those of the base classes, which causes two problems:
- It hampers training.
- A large gap between base and novel classes does not reflect reality.
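To make the magnitude issue concrete, here is a minimal sketch (hypothetical tensors, not the paper's code) contrasting the two classifiers: after L2-normalizing both sides, every cosine score is bounded by the scale factor, so a weight vector with an unusually large norm cannot dominate the softmax.

```python
import torch
import torch.nn.functional as F

def dot_product_scores(features, weights):
    # Plain dot-product classifier: scores scale with ||w||, so novel-class
    # weights with unusual magnitudes distort the softmax.
    return features @ weights.t()

def cosine_scores(features, weights, tau=10.0):
    # Cosine classifier: normalize both sides, then scale by tau so the
    # scores are not confined to [-1, 1].
    f = F.normalize(features, p=2, dim=1)
    w = F.normalize(weights, p=2, dim=1)
    return tau * (f @ w.t())

features = torch.randn(4, 64)   # 4 examples, 64-dim features
weights = torch.randn(5, 64)    # 5 class-weight vectors
weights[0] *= 100.0             # exaggerate one class's weight magnitude

dot = dot_product_scores(features, weights)   # class 0 dominates
cos = cosine_scores(features, weights)        # all scores bounded by tau
```

With the dot product, class 0's scores are roughly 100x larger than the others'; with cosine similarity, the inflated norm has no effect on the scores.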
## Problem Setting
- Base categories come with abundant training data.
- Without forgetting the base categories or retraining, the system must recognize not only the base categories but also novel ones, whose examples are provided only at test time.
## Method
Training proceeds in two stages:
- Stage 1: train normally on the base-class data to obtain the parameters of the feature extractor and the classifier.
  - Open question: does the classifier in this stage use the dot product or cosine similarity?
- Stage 2: freeze the feature extractor's parameters and learn the parameters of the weight generator.
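The stage-two setup can be sketched as follows; the two modules here are placeholders standing in for the paper's ConvNet feature extractor and attention-based generator, not the authors' actual architectures.

```python
import torch
import torch.nn as nn

# Placeholder modules (the real ones are a ConvNet and the attention-based
# weight generator defined below).
feature_extractor = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
weight_generator = nn.Linear(8, 8)

# Stage 2: freeze the feature extractor; only the generator is trained.
for p in feature_extractor.parameters():
    p.requires_grad = False

trainable = [p for p in weight_generator.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)
```

Passing only the generator's parameters to the optimizer guarantees the frozen extractor stays fixed even if `requires_grad` were later flipped back on.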
## The core contribution: how the classification-weight generator works

- Feature vectors are extracted from the few novel-class examples and fed, together with the existing base-class weights, into the weight generator, which outputs the classification weight for the novel class: $w_n' = G(Z_n', W_{base} \mid \phi)$, where $\phi$ denotes the generator's learnable parameters.
- What does $G$ compute, and where does $\phi$ enter?
  - Averaging branch: take the mean of the novel examples' feature vectors,
    $$w_{avg}' = \frac{1}{N'} \sum_{i=1}^{N'} \bar{z}_i'$$
  - Attention branch: a learnable matrix $\phi_q$ maps each feature vector to a query, which attends over the learnable keys $k_b$ of the base categories:
    $$w_{att}' = \frac{1}{N'} \sum_{i=1}^{N'} \sum_{b=1}^{K_{base}} Att\left(\phi_q \bar{z}_i', k_b\right) \cdot \bar{w}_b$$
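A toy numeric sketch of the two branches with random tensors (the query layer is simplified to the identity, and the final elementwise combination of the branches via vectors `phi_avg` and `phi_att` is my reading of the paper; treat both as assumptions):

```python
import torch

N, D, Kbase = 5, 64, 10          # shots N', feature dim, number of base classes
z = torch.randn(N, D)            # feature vectors of the N' novel examples
w_base = torch.randn(Kbase, D)   # base classification weights w_b

# Averaging branch: w'_avg = (1/N') * sum_i z_i
w_avg = z.mean(dim=0)

# Attention branch (identity query layer for simplicity): each example's
# query attends over the learnable base-class keys k_b.
keys = torch.randn(Kbase, D)
att = torch.softmax(z @ keys.t(), dim=1)   # [N' x Kbase], rows sum to 1
w_att = (att @ w_base).mean(dim=0)         # average over the N' examples

# Combine the branches elementwise (phi_avg, phi_att are learnable in the
# paper; fixed to ones here).
phi_avg = torch.ones(D)
phi_att = torch.ones(D)
w_novel = phi_avg * w_avg + phi_att * w_att
```

The attention branch lets a novel class "borrow" weight directions from similar base classes, which is especially useful in the 1-shot regime where the average of a single feature vector is noisy.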
-
## talk is cheap, show me the code
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionBasedBlock(nn.Module):
    def __init__(self, nFeat, nK, scale_att=10.0):
        super(AttentionBasedBlock, self).__init__()
        self.nFeat = nFeat
        # Query layer initialized close to the identity mapping.
        self.queryLayer = nn.Linear(nFeat, nFeat)
        self.queryLayer.weight.data.copy_(
            torch.eye(nFeat, nFeat) + torch.randn(nFeat, nFeat) * 0.001)
        self.queryLayer.bias.data.zero_()
        # Learnable temperature for the attention softmax.
        self.scale_att = nn.Parameter(
            torch.FloatTensor(1).fill_(scale_att), requires_grad=True)
        # Learnable keys, one per base category.
        wkeys = torch.FloatTensor(nK, nFeat).normal_(0.0, np.sqrt(2.0 / nFeat))
        self.wkeys = nn.Parameter(wkeys, requires_grad=True)

    def forward(self, features_train, labels_train, weight_base, Kbase):
        batch_size, num_train_examples, num_features = features_train.size()
        nKbase = weight_base.size(1)  # weight_base: [batch_size x nKbase x num_features]
        labels_train_transposed = labels_train.transpose(1, 2)
        nKnovel = labels_train_transposed.size(1)  # [batch_size x nKnovel x num_train_examples]

        features_train = features_train.view(
            batch_size * num_train_examples, num_features)
        Qe = self.queryLayer(features_train)
        Qe = Qe.view(batch_size, num_train_examples, self.nFeat)
        Qe = F.normalize(Qe, p=2, dim=Qe.dim() - 1, eps=1e-12)

        wkeys = self.wkeys[Kbase.view(-1)]  # the keys of the base categories
        wkeys = F.normalize(wkeys, p=2, dim=wkeys.dim() - 1, eps=1e-12)
        # Transpose from [batch_size x nKbase x nFeat]
        # to [batch_size x nFeat x nKbase]
        wkeys = wkeys.view(batch_size, nKbase, self.nFeat).transpose(1, 2)

        # Attention coefficients: Qe * wkeys ==>
        # [batch_size x num_train_examples x nKbase] =
        # [batch_size x num_train_examples x nFeat] * [batch_size x nFeat x nKbase]
        AttentionCoefficients = self.scale_att * torch.bmm(Qe, wkeys)
        AttentionCoefficients = F.softmax(
            AttentionCoefficients.view(batch_size * num_train_examples, nKbase),
            dim=1)
        AttentionCoefficients = AttentionCoefficients.view(
            batch_size, num_train_examples, nKbase)

        # weight_novel = AttentionCoefficients * weight_base ==>
        # [batch_size x num_train_examples x num_features] =
        # [batch_size x num_train_examples x nKbase] * [batch_size x nKbase x num_features]
        weight_novel = torch.bmm(AttentionCoefficients, weight_base)

        # Sum the per-example weights of each novel class, then divide by the
        # shot count to average:
        # [batch_size x nKnovel x num_features] =
        # [batch_size x nKnovel x num_train_examples] * [batch_size x num_train_examples x num_features]
        weight_novel = torch.bmm(labels_train_transposed, weight_novel)
        weight_novel = weight_novel.div(
            labels_train_transposed.sum(dim=2, keepdim=True).expand_as(weight_novel))
        return weight_novel
```
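For reference, a sketch of how the forward inputs can be constructed (shapes inferred from the comments above; the episode sizes here are made up): `features_train` holds the support-set features, `labels_train` is one-hot over the novel classes, `weight_base` holds the base weights, and `Kbase` indexes which base categories participate in the episode.

```python
import torch
import torch.nn.functional as F

batch_size, n_novel, n_shot, n_feat, n_base = 2, 5, 1, 32, 10

# Support-set features: [batch_size x (n_novel * n_shot) x n_feat]
features_train = torch.randn(batch_size, n_novel * n_shot, n_feat)

# One-hot support labels: [batch_size x (n_novel * n_shot) x n_novel]
labels = torch.arange(n_novel).repeat_interleave(n_shot)
labels_train = F.one_hot(labels, n_novel).float()
labels_train = labels_train.unsqueeze(0).expand(batch_size, -1, -1)

# Base classification weights and the base-category indices of the episode.
weight_base = torch.randn(batch_size, n_base, n_feat)
Kbase = torch.arange(n_base).unsqueeze(0).repeat(batch_size, 1)
```

With these inputs, `AttentionBasedBlock(nFeat=32, nK=10)` should return `weight_novel` of shape `[2, 5, 32]`: one generated weight vector per novel class per episode.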
## Datasets
- Mini-ImageNet
- The low-shot benchmark of Hariharan and Girshick [B. Hariharan and R. Girshick. Low-shot visual recognition by shrinking and hallucinating features. arXiv preprint arXiv:1606.02819, 2016.]
## Conclusion
Evaluated on Mini-ImageNet, the method reaches 56.2% accuracy in the 1-shot setting and 73.0% in the 5-shot setting, while losing no accuracy on the base classes.