匹配网络 Matching Network
匹配网络其实就是引入注意力机制,通过对 embedding 后的特征计算注意力,利用注意力得分进行分析:

首先也是对支持集和查询集进行 embedding,然后用查询集样本对每个支持集样本计算注意力:
a
(
x
^
,
x
i
)
=
e
c
(
f
(
x
^
)
,
g
(
x
i
)
)
/
∑
j
=
1
k
e
c
(
f
(
x
^
)
,
g
(
x
j
)
)
a\left(\hat{x}, x_{i}\right)=e^{c\left(f(\hat{x}), g\left(x_{i}\right)\right)} / \sum_{j=1}^{k} e^{c\left(f(\hat{x}), g\left(x_{j}\right)\right)}
a(x^,xi)=ec(f(x^),g(xi))/j=1∑kec(f(x^),g(xj))
其中:
- f 和 g是我们选择的合适的神经网络,一般 f = g,用于输入的 embedding
- x i x_i xi 是支持集, x ^ \hat x x^ 是查询集
- c 是余弦距离
计算了注意力之后,就分析查询集的样本:
P
(
y
^
∣
x
^
,
S
)
=
∑
i
=
1
k
a
(
x
^
,
x
i
)
y
i
P(\hat{y} \mid \hat{x}, S)=\sum_{i=1}^{k} a\left(\hat{x}, x_{i}\right) y_{i}
P(y^∣x^,S)=i=1∑ka(x^,xi)yi
其中:
- y i y_i yi 是每个类别的标签,其实就是把每个类别根据注意力得分进行线性加权
- P 是计算出对应类别的概率
最后的训练目标为:
θ
=
arg
max
θ
E
L
∼
T
[
E
S
∼
L
,
B
∼
L
[
∑
(
x
,
y
)
∈
B
log
P
θ
(
y
∣
x
,
S
)
]
]
\theta=\arg \max _{\theta} E_{L \sim T}\left[E_{S \sim L, B \sim L}\left[\sum_{(x, y) \in B} \log P_{\theta}(y \mid x, S)\right]\right]
θ=argθmaxEL∼T⎣⎡ES∼L,B∼L⎣⎡(x,y)∈B∑logPθ(y∣x,S)⎦⎤⎦⎤
个人总结:
总的来说,匹配网络把整个分析的过程都简化到注意力计算过程中,如果某个类别的注意力得分比较高,其实就意味着测试样本属于这个类别的可能性比较大,所以模型的训练重点就回到最初的 embedding 中。
Vinyals O, Blundell C, Lillicrap T, et al. Matching networks for one shot learning[J]. Advances in neural information processing systems, NIPS 2016, 29: 3630-3638.
元学习系列(四):Matching Network(匹配网络)
1123

被折叠的 条评论
为什么被折叠?



