Meta-Transfer Learning for Few-Shot Learning 论文地址 代码
写在前面
这是cvpr19年又一篇用meta-learning 做 few-shot learning的文章, 跟MAML不同的地方, 就是它不做全部参数的finetune了, 需要更新的参数变少, 使得网络不容易过拟合当前的task, 也使得网路收敛的快, 同时还使用了hard sample mining的方式来提高网络的精度.
Motivation
MAML这种meta-learning 做 few-shot 的方法存在两个主要的问题:
- 这类方法一般都需要很多相似的任务来实现meta-train;
- 之前都是用浅的网络来做few-shot这种task, 因为网络一深容易过拟合, 所以很难去使用一些比较厉害的深网络.
Contribution
- 这篇文章提出了一个meta-transfer learning (MTL) 方法来结合迁移学习和meta-learning , 使得预训练的深度网络迁移到few-shot 的任务上, 网路收敛快并且避免了过拟合;
- 提出了一个Hard Task (HT) meta-batch 训练策略, 使得网络在艰难样本下性能更好.
Algorithm
总的流程图如下:
1. DNN training on large-scale data
这边网络的预训练就是与一般的分类网络训练方式一样, 并不是用meta-train的方式来训的, 就比如如果原始的数据集中有64个类, 则训练出来的分类器也是输出64个分类score的. 这篇文章把网络分成了特征提取器
Θ
\Theta
Θ 和base-learner
θ
\theta
θ (其实就是分类器). 网络训练的时候随机初始化
Θ
\Theta
Θ 和
θ
\theta
θ , 其更新方式就是一般的梯度下降的方式:
[
Θ
;
θ
]
=
:
[
Θ
;
θ
]
−
α
∇
L
D
(
[
Θ
;
θ
]
)
,
[\Theta; \theta] = : [\Theta; \theta]-\alpha \nabla L_D([\Theta; \theta]),
[Θ;θ]=:[Θ;θ]−α∇LD([Θ;θ]),
L
D
(
[
Θ
;
θ
]
)
=
1
∣
D
∣
∑
(
x
,
y
)
∈
D
l
(
f
[
Θ
;
θ
]
(
x
)
,
y
)
L_D([\Theta; \theta]) = \frac{1}{|D|}\sum_{(x,y) \in D} l(f_{[\Theta; \theta]}(x),y)
LD([Θ;θ])=∣D∣1(x,y)∈D∑l(f[Θ;θ](x),y)
这里的
l
l
l 指的是一些距离度量函数, 可以是交叉熵损失,
α
\alpha
α 是学习率. 当训练完后,
θ
\theta
θ 会被砍掉, 然后接新的分类器, 因为few-shot 的任务sample的类别数和总的训练数据集肯定是不一样的.
2. Meta-transfer learning (MTL)
这块介绍了他们提出的scaling 和 Shifting (SS) Φ S { 1 , 2 } \Phi_{S_{\{1,2\}}} ΦS{1,2} 模块, 就是他们在迁移的时候不是直接更新网络参数, 而是在原始的weight和bias上做了一些操作, 这样不仅减少了参数的学习, 同时也保留了预训练时学到的general的信息不被破坏, 减少了过拟合的概率. 细节如下:
对于一个task
T
=
{
T
(
t
r
)
,
T
(
t
e
)
}
T = \{T^{(tr)}, T^{(te)} \}
T={T(tr),T(te)},
T
(
t
r
)
T^{(tr)}
T(tr)的损失就是用来更新当前的base-learner (classifier)
θ
′
\theta'
θ′ :
θ
′
←
θ
−
β
∇
θ
L
T
(
t
r
)
(
[
Θ
;
θ
]
,
Φ
S
{
1
,
2
}
)
\theta' \leftarrow \theta- \beta \nabla_{\theta} L_{T^{(tr)}}([\Theta; \theta] , \Phi_{S_{\{1,2\}}})
θ′←θ−β∇θLT(tr)([Θ;θ],ΦS{1,2})
这里的
θ
\theta
θ 和上一节的
θ
\theta
θ 不一样, 这里的
θ
\theta
θ 是指只有某几个类的分类器参数; 然后跟上一节不一样的地方就是, 这里不更新
Θ
\Theta
Θ,
Θ
\Theta
Θ 作为特征提取器, 在预训练完之后就一直不会变.
然后
T
(
t
e
)
T^{(te)}
T(te) 的损失是用来更新
Φ
S
{
1
,
2
}
\Phi_{S_{\{1,2\}}}
ΦS{1,2} , 这里
Φ
S
1
\Phi_{S_1}
ΦS1初始化为全1,
Φ
S
2
\Phi_{S_2}
ΦS2 初始化为0. 更新过程如下:
Φ
S
i
=
:
Φ
S
i
−
γ
∇
Φ
S
i
L
T
(
t
e
)
(
[
Θ
;
θ
′
]
,
Φ
S
{
1
,
2
}
)
\Phi_{S_i} = : \Phi_{S_i} - \gamma \nabla_{ \Phi_{S_i}} L_{T^{(te)}}([\Theta; \theta'] , \Phi_{S_{\{1,2\}}})
ΦSi=:ΦSi−γ∇ΦSiLT(te)([Θ;θ′],ΦS{1,2})
θ
=
:
θ
−
γ
∇
Φ
S
i
L
T
(
t
e
)
(
[
Θ
;
θ
′
]
,
Φ
S
{
1
,
2
}
)
\theta = :\theta - \gamma \nabla_{ \Phi_{S_i}} L_{T^{(te)}}([\Theta; \theta'] , \Phi_{S_{\{1,2\}}})
θ=:θ−γ∇ΦSiLT(te)([Θ;θ′],ΦS{1,2})
下面介绍如何将
Φ
S
{
1
,
2
}
\Phi_{S_{\{1,2\}}}
ΦS{1,2} 应用到网络中,
Θ
\Theta
Θ 中的所有参数都用
W
,
b
W,b
W,b表示, 所以对于一个输入
X
X
X, 经过SS后提取的特征:
S
S
(
X
;
W
,
b
;
Φ
S
{
1
,
2
}
)
=
(
W
⊙
Φ
S
1
)
X
+
(
b
+
Φ
S
2
)
SS(X; W, b;\Phi_{S_{\{1,2\}}}) = (W \odot \Phi_{S_1})X+(b+\Phi_{S_2})
SS(X;W,b;ΦS{1,2})=(W⊙ΦS1)X+(b+ΦS2)
如下图:
SS有以下三个优点:
- 利用了深度的DNN提供了一个strong的初始化, 使得MTL可以很快收敛;
- 没有改变DNN的权重, 避免破坏原始网络学到的general的信息;
- SS是很轻量的, 减少了过拟合的概率.
3 Hard task (HT) meta-batch
这边就是对于所有的task , 先用当前的网络去测试一下, 将每个task中分类acc最低的类别记录下来, 在所有的train完之后, 再用这些类别的sample组合成难的task来训练, 这是参考了课程学习的概念, 使得网络在逐步变难的样本中逐步增加网络的分类性能.
本文的算法流程可以从以下两张图和直观的看出:
这里我比较有疑问的地方就是 θ ′ \theta' θ′ 和 θ \theta θ 之间的关系, 若是 θ ′ \theta' θ′ 的更新都是根据 θ \theta θ, 那算法2中的4, 有什么意义吗, 若是 θ \theta θ 不变, 这边迭代就没有意义了. 可能需要看代码才知道了.