Matching Networks for One Shot Learning
Abstract
研究领域: One Shot Learning(小样本学习)从少量样本中快速学习,是传统监督学习和Deep Learning无法解决的问题,该研究领域被称为小样本学习。
创新:
以下两种方法结合:
- metric learning目前,小样本学习的主流方法
- external memories以前小样本学习的主流方法
数据集: Omniglot & ImageNet
1. Introduction
- 人类可以从少量样本中学习新的概念。比如:一个小朋友看到邻居家的新玩具一次,下次跟妈妈去商场的时候马上就能从货架上认出它来。
- 现在Deep Learning仍然需要大数据的驱动。
- 一些non-parametric model可以快速学习新样本,比如KNN。本文要融合parametric model(即DL)和non-parametric model。DL中的样本是用完即弃的,而KNN中的样本会被保存。
- 本文的两个创新点:提出Matching Net(Section 2.1)& 新的训练测试方式(Section 2.2)
- 本文还为在Omniglot & ImageNet上的One Shot Learning实验设置了benchmark。
2. Model
2.1 Model Architecture
- 在网络上加external memories。
- external memories有很多种。在seq2seq中,external memories用于对 P ( B ∣ A ) , w h e r e A a n d B c a n b e a s e q u e n c e P(B|A),\ where\ A\ and\ B\ can\ be\ a\ sequence P(B∣A), where A and B can be a sequence的建模。在本文的Matching Net中,也用这种方式,只不过这里的 A , B A,\ B A, B是一个set。如上图所示,网络的输入是有多个图片组成的set。
- 数学建模部分。这里比较复杂,我会讲的详细一点。
从上图可以看到,左边4个图片形成一组,称为support set;右下1个单身狗,称为test example。全部5个图片称为1个task。
该模型用函数可表示为 p r e d i c t i o n = f ( s u p p o r t _ s e t , t e s t _ e x a m p l e ) prediction = f(support\_set,\ test\_example) prediction=f(support_set, test_example),即模型有两个输入。该模型用概率可表示为 P ( y ^ ∣ x ^ , S ) P(\hat y|\hat x, S) P(y^∣x^,S), 其中 S = { ( x i , y i ) } i = 1 k S = \{(x_i, y_i)\}_{i=1}^k S={(xi,yi)}i=1k,k表示support set中样本的个数。上图support set有4个图片,k=4。
Matching Net作者把该模型表示为:
y
^
=
∑
i
=
1
k
a
(
x
^
,
x
i
)
y
i
\hat y = \sum_{i=1}^k a(\hat x, x_i) y_i
y^=i=1∑ka(x^,xi)yi
预测值 y ^ \hat y y^被看做是support set中样本的labels的线性组合,组合的权重是test example和support set中1个样本的关系—— a ( x ^ , x i ) a(\hat x, x_i) a(x^,xi)。
- 将 a ( x ^ , x i ) a(\hat x, x_i) a(x^,xi)作为一个核函数,则该模型可近似为:Deep Learning做嵌入层,KDE做分类层。
- 将 a ( x ^ , x i ) a(\hat x, x_i) a(x^,xi)作为一个01函数,则该模型可金思维:Deep Learning做嵌入层,KNN做分类层。
2.1.1 The Attention Kernel
本文赋予 a ( x ^ , x i ) a(\hat x, x_i) a(x^,xi)新的形式——将它看做attention kernel。此时,模型的预测结果就是support set中attention最多的图片的label。
常见的attention kernel是cosine距离上的softmax:
a
(
x
^
,
x
i
)
=
e
c
(
f
(
x
^
)
,
g
(
x
i
)
)
∑
j
=
1
k
e
c
(
f
(
x
^
)
,
g
(
x
j
)
)
a(\hat x, x_i) = \frac {e^{c(f(\hat x), g(x_i))}}{\sum_{j=1}^k e^{c(f(\hat x), g(x_j))}}
a(x^,xi)=∑j=1kec(f(x^),g(xj))ec(f(x^),g(xi)),其中
f
,
g
f, g
f,g是两个嵌入函数(可由神经网络实现,如:VGG or Inception)。
2.1.2 Full Context Embeddings
嵌入向量 e m b _ x i = g ( x i ) ← g ( x i , S ) emb\_x_i = g(x_i) \leftarrow g(x_i, S) emb_xi=g(xi)←g(xi,S),嵌入函数的输出同时由对应的 x i x_i xi和整个support set有关。support set是每次随机选取的,嵌入函数同时考虑support set和 x i x_i xi可以消除随机选择造成的差异性。类似机器翻译中word和context的关系, S S S可以看做是 x i x_i xi的context,所以本文在嵌入函数中用到了LSTM。
对text example的嵌入函数为
f
f
f:
f
(
x
^
,
S
)
=
attLSTM
(
f
′
(
x
^
)
,
g
(
S
)
,
K
)
f(\hat x, S) = \textbf{attLSTM}(f'(\hat x), g(S), K)
f(x^,S)=attLSTM(f′(x^),g(S),K),其中
f
′
(
x
^
)
f'(\hat x)
f′(x^)是CNN嵌入层的输出,可以是VGG或Inception,
g
(
S
)
g(S)
g(S)是support set中样本的嵌入函数输出,K是LSTM层的timesteps,等于support set的图片个数。
详解full context embedding:
The Fully Conditional Embedding f
h ^ k , c k = L S T M ( f ′ ( x ^ ) , [ h k − 1 , r k − 1 ] , c k − 1 ) \hat h_k, c_k = LSTM(f'(\hat x), [h_{k-1}, r_{k-1}], c_{k-1}) h^k,ck=LSTM(f′(x^),[hk−1,rk−1],ck−1)
h k = h ^ k + f ′ ( x ^ ) h_k = \hat h_k + f'(\hat x) hk=h^k+f′(x^)
r k − 1 = ∑ i = 1 ∣ S ∣ a ( h k − 1 , g ( x i ) ) g ( x i ) r_{k-1} = \sum_{i=1}^{|S|} a(h_{k-1}, g(x_i))g(x_i) rk−1=i=1∑∣S∣a(hk−1,g(xi))g(xi)
a ( h k − 1 , g ( x i ) ) = s o f t m a x ( h k − 1 T g ( x i ) ) a(h_{k-1}, g(x_i)) = softmax(h_{k-1}^Tg(x_i)) a(hk−1,g(xi))=softmax(hk−1Tg(xi))
The Fully Conditional Embedding g
support set中的 x i x_i xi在经过多层卷积网路后,在经过一层bidirectional LSTM。
2. 2Training Strategy
- 一个batch包括多个task;
- 一个task包括一个support set和一个test example;
- 一个support set包括多个sample(image & label);
- support set中有且只有一个样本与test example同类。
Related Work
Memory Augumented Neural Networks attention机制
Metric Learning 比较学习