Episode-Based Prototype Generating Network for ZSL
1 Motivation
The training classes are split into two disjoint subsets, mimicking the support set and query set in few-shot learning. The key idea borrowed is updating the model with the loss on the query set to improve its generalization ability. In few-shot learning, this generalization is the ability to classify quickly after learning from a few labeled samples; in zero-shot learning, it is the ability to classify unseen classes quickly after learning from the seen classes.
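The episode construction can be sketched as follows; the helper name and data layout here are illustrative, not from the paper:

```python
import random

def make_episode(seen_classes, num_query_classes, data_by_class):
    """Split the seen classes into a support set and a query set.

    The query classes play the role of "fake unseen" classes inside each
    training episode; the loss on them drives the update that is meant to
    improve generalization to truly unseen classes.
    """
    classes = list(seen_classes)
    random.shuffle(classes)
    query_classes = classes[:num_query_classes]
    support_classes = classes[num_query_classes:]
    support = {c: data_by_class[c] for c in support_classes}
    query = {c: data_by_class[c] for c in query_classes}
    return support, query

# Toy data: 6 seen classes, each with a list of sample ids.
data = {c: [f"{c}_{i}" for i in range(3)] for c in range(6)}
support, query = make_episode(range(6), num_query_classes=2, data_by_class=data)
```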
2 Method
2.1 Prototype Generating Network
$$L_{\mathcal{V \to A}} = \sum_i \| F(\textbf x_i) - \textbf a_i \|_2^2. \tag{1}$$
$$L_{\mathcal{A \to V}} = \sum_i \| G(\textbf a_i) - \textbf x_i \|_2^2. \tag{2}$$
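A minimal numpy sketch of the two reconstruction losses in Eqs. (1)–(2), with toy linear maps standing in for the embedding network F and the generator G (all names and dimensions here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins for the two mappings: F projects visual features to the
# attribute space, G generates visual prototypes from attributes.
W_F = rng.normal(size=(4, 3))   # visual dim 4 -> attribute dim 3
W_G = rng.normal(size=(3, 4))   # attribute dim 3 -> visual dim 4

def F(x):  # visual -> attribute
    return x @ W_F

def G(a):  # attribute -> visual
    return a @ W_G

x = rng.normal(size=(5, 4))     # 5 visual feature vectors
a = rng.normal(size=(5, 3))     # their class attribute vectors

# Eq. (1): squared L2 distance between F(x_i) and a_i, summed over samples.
loss_v2a = np.sum(np.linalg.norm(F(x) - a, axis=1) ** 2)
# Eq. (2): squared L2 distance between G(a_i) and x_i, summed over samples.
loss_a2v = np.sum(np.linalg.norm(G(a) - x, axis=1) ** 2)
```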
$$L_{WGAN} = \mathbb{E}[D(\textbf x, \textbf a)] - \mathbb{E}[D(\tilde{\textbf x}, \tilde{\textbf a})] - \lambda \, \mathbb{E}\big[ (\| \nabla D(\hat{\textbf x}, \hat{\textbf a}) \|_2 - 1)^2 \big], \tag{3}$$

where $\tilde{\textbf a} = F(\textbf x)$ and $\tilde{\textbf x} = G(\textbf a)$ are the generated attribute and visual prototypes, and $\hat{\textbf x} = \tau \textbf x + (1 - \tau) \tilde{\textbf x}$, $\hat{\textbf a} = \tau \textbf a + (1 - \tau) \tilde{\textbf a}$ with $\tau \sim U(0, 1)$ are the interpolates used for the gradient penalty.
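The WGAN objective with gradient penalty can be illustrated with a deliberately simple linear critic, for which the gradient with respect to the input is just the weight vector itself, so the penalty has a closed form (a toy sketch, not the paper's network):

```python
import numpy as np

rng = np.random.default_rng(1)

dv, da, n = 4, 3, 5
w = rng.normal(size=(dv + da,))   # linear critic weights

def D(x, a):
    # Critic scores a (visual, attribute) pair with a single dot product.
    return np.concatenate([x, a], axis=1) @ w

x_real = rng.normal(size=(n, dv)); a_real = rng.normal(size=(n, da))  # real pairs
x_gen = rng.normal(size=(n, dv));  a_gen = rng.normal(size=(n, da))   # generated pairs

# Interpolates between real and generated pairs, tau ~ U(0, 1) per sample,
# as in Eq. (3)'s hat_x and hat_a.
tau = rng.uniform(size=(n, 1))
x_hat = tau * x_real + (1 - tau) * x_gen
a_hat = tau * a_real + (1 - tau) * a_gen

# For a linear critic, grad of D w.r.t. any input (including the
# interpolates above) is w, so the gradient penalty is a single scalar.
penalty = (np.linalg.norm(w) - 1.0) ** 2

lam = 10.0
loss_wgan = D(x_real, a_real).mean() - D(x_gen, a_gen).mean() - lam * penalty
```

In a real implementation the critic is a neural network and the penalty is computed per interpolate with autograd; the linear case only makes the structure of Eq. (3) concrete.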
$$p_i^{\mathcal V}(\textbf x) = \frac{\exp(\textbf x^T G(\textbf a_i))}{\sum_j \exp(\textbf x^T G(\textbf a_j))}. \tag{4}$$

$$p_i^{\mathcal S}(\textbf x) = \frac{\exp(F(\textbf x)^T \textbf a_i)}{\sum_j \exp(F(\textbf x)^T \textbf a_j)}. \tag{5}$$

$$\mathcal L_{MCE} = -\sum_{\textbf x} \log p_i^{\mathcal V}(\textbf x) - \sum_{\textbf x} \log p_i^{\mathcal S}(\textbf x), \tag{6}$$

where $i$ is the ground-truth class of each $\textbf x$.
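Eqs. (4)–(6) can be sketched in numpy with toy linear maps standing in for F and G (names and dimensions are illustrative):

```python
import numpy as np

rng = np.random.default_rng(2)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # subtract max for stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

K, dv, da, n = 4, 5, 3, 6                   # classes, dims, samples
W_F = rng.normal(size=(dv, da))             # stand-in for F: visual -> semantic
W_G = rng.normal(size=(da, dv))             # stand-in for G: semantic -> visual
A = rng.normal(size=(K, da))                # class attribute vectors a_k
X = rng.normal(size=(n, dv))                # visual features
y = rng.integers(0, K, size=n)              # ground-truth labels

# Eq. (4): classify in visual space against generated prototypes G(a_k).
p_v = softmax(X @ (A @ W_G).T)
# Eq. (5): classify in semantic space after mapping features with F.
p_s = softmax((X @ W_F) @ A.T)
# Eq. (6): sum of the two negative log-likelihoods at the true labels.
loss_mce = -np.log(p_v[np.arange(n), y]).sum() - np.log(p_s[np.arange(n), y]).sum()
```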
$$\min_G \max_D L_{WGAN} + \alpha L_{\mathcal{V \to A}} + \beta L_{\mathcal{A \to V}} + \gamma \mathcal L_{MCE}. \tag{7}$$
2.2 Refining Model
Given an input $\textbf x_t$, the predicted class label is:

$$\hat y_t = \arg\min_k d(\textbf x_t, G(\textbf a_k)). \tag{8}$$
Given an input $\textbf x_t$, the predicted class probability is:

$$p_G(y = k \mid \textbf x_t) = \frac{\exp(-d(\textbf x_t, G(\textbf a_k)))}{\sum_{k'} \exp(-d(\textbf x_t, G(\textbf a_{k'})))}. \tag{9}$$
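The prediction rules in Eqs. (8)–(9) with Euclidean distance $d$, using random vectors as stand-ins for the generated prototypes $G(\textbf a_k)$:

```python
import numpy as np

rng = np.random.default_rng(3)

K, dv = 4, 5
prototypes = rng.normal(size=(K, dv))             # stand-ins for G(a_k), one per class
x_t = prototypes[2] + 0.01 * rng.normal(size=dv)  # test sample near class 2

# Eq. (8): predict the class whose generated prototype is nearest.
d = np.linalg.norm(prototypes - x_t, axis=1)
y_hat = int(np.argmin(d))

# Eq. (9): softmax over negative distances gives class probabilities.
p = np.exp(-d) / np.exp(-d).sum()
```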
3 Experiments
3.1 Comparison Experiments
This is not transductive zero-shot learning; for the inductive setting, the results count as high.
3.2 Effect of Episodes
The results are reported with 0, 5, 10, and 15 training classes held out as the query set; 0 means the episode-based scheme is not used.
3.3 Ablation Study
A check mark (√) means the corresponding coefficient is set to 1; a blank means it is set to 0.
4 Summary
Strong results, obtained in the inductive setting.