使用二阶梯度作正则项交叉训练参数
在上周五讨论时关于交叉训练"语义概念参数"和"视觉概念参数"时我们说到了导致正确率底下的两个缺陷:
- 训练样本少:这个即将解决,因为我们上次讨论将进行关于"指令"的概念训练,这比需要逻辑推理的任务更加简单,而也可以生成更多的样本;
- 训练方式的问题。现在我引入一个有效的正则项:这是在凸优化的数学原理上来优化,这是本文要谈的;
值得注意的是,这个方法(用二阶梯度作正则项)不止对我们当前这个任务适用,我更感觉这是一个通用的适用于损失是凸函数的方法;
再谈模型的损失
注意我们上次说到,模型的预测是one-hot形式的空间概念指代:[上下,左右,左上右下… …],那么最后一层适用softmax交叉熵损失来做,那么损失是:
L ( A s ) = − y ^ ⊙ l o g ( f s o f t m a x ( A s ⊙ C ∗ T ) ) \mathcal{L}(A_s) = -\hat{y} \odot log(f_{softmax}(A_s \odot C^{*T}) ) L(As)=−y^⊙log(fsoftmax(As⊙C∗T))
另一边,我们说到, A s A_s As也在另一个网络中发挥权重参数的作用,我们令来自那个网络的损失为 L ^ ( A s ) \hat{\mathcal{L}}(A_s) L^(As);下面我们将证明,这是一个凸函数。
softmax是凸函数
这个证明也可以跳过,只需要记住softmax是凸函数这个结论也行,证明如下,softmax的交叉熵损失是:
L ( w 1 , w 2 , ⋯   , w k ) = − 1 m [ ∑ i = 1 m ∑ j = 1 k 1 { y ( i ) = j } log e w j T x ( i ) ∑ l = 1 k e w l T x ( i ) ] \mathcal{L}(w_1, w_2,\cdots, w_k)=-\frac1m\left[\sum_{i=1}^m\sum_{j=1}^k 1\{y^{(i)}=j\}\log\frac{e^{w^T_jx^{(i)}}}{\sum_{l=1}^ke^{w^T_lx^{(i)}}}\right] L(w1,w2,⋯,wk)=−m1[i=1∑mj=1∑k1{y(i)=j}log∑l=1kewlTx(i)ewjTx(i)]
现在令 :
a
j
=
e
w
j
T
x
∑
l
=
1
k
e
w
l
T
x
a_j=\frac{e^{w^T_jx}}{\sum_{l=1}^ke^{w^T_lx}}
aj=∑l=1kewlTxewjTx
分情况,当
n
≠
j
n\neq j
n̸=j:
∇
w
n
a
j
=
−
e
w
j
T
x
e
w
n
T
x
(
∑
l
=
1
k
e
w
l
T
x
)
2
x
=
−
a
j
a
n
x
\nabla_{w_n} a_j=-\frac{e^{w_j^Tx}e^{w_n^Tx}}{\left(\sum_{l=1}^ke^{w_l^Tx}\right)^2}x=-a_ja_nx
∇wnaj=−(∑l=1kewlTx)2ewjTxewnTxx=−ajanx
当
n
=
j
n= j
n=j:
∇
w
j
a
j
=
(
e
w
j
T
x
∑
l
=
1
k
e
w
l
T
x
−
e
w
j
T
x
∑
l
=
1
k
e
w
l
T
x
e
w
j
T
x
∑
l
=
1
k
e
w
l
T
x
)
x
=
a
j
(
1
−
a
j
)
x
\nabla_{w_j} a_j=\left(\frac{e^{w_j^Tx}}{\sum_{l=1}^ke^{w_l^Tx}}-\frac{e^{w_j^Tx}}{\sum_{l=1}^ke^{w_l^Tx}}\frac{e^{w_j^Tx}}{\sum_{l=1}^ke^{w_l^Tx}}\right)x=a_j(1-a_j)x
∇wjaj=(∑l=1kewlTxewjTx−∑l=1kewlTxewjTx∑l=1kewlTxewjTx)x=aj(1−aj)x
所以有:
∇
w
n
C
=
−
1
m
∑
i
=
1
m
(
∑
j
≠
n
−
1
{
y
(
i
)
=
j
}
a
j
a
n
/
a
j
x
(
i
)
+
1
{
y
(
i
)
=
n
}
a
n
(
1
−
a
n
)
/
a
n
x
(
i
)
)
=
−
1
m
∑
i
=
1
m
[
x
(
i
)
(
1
{
y
(
i
)
=
n
}
−
a
n
)
]
\nabla_{w_n}C=-\frac1m\sum_{i=1}^m\left(\sum_{j\neq n}-1\{y^{(i)}=j\}a_ja_n/a_jx^{(i)}+1\{y^{(i)}=n\}a_n(1-a_n)/a_nx^{(i)}\right) =-\frac1m\sum_{i=1}^m\left[x^{(i)}(1\{y^{(i)}=n\}-a_n)\right]
∇wnC=−m1i=1∑m⎝⎛j̸=n∑−1{y(i)=j}ajan/ajx(i)+1{y(i)=n}an(1−an)/anx(i)⎠⎞=−m1i=1∑m[x(i)(1{y(i)=n}−an)]
注意
1
{
y
(
i
)
=
j
}
1\{y^{(i)}=j\}
1{y(i)=j}只有当
y
(
i
)
=
j
y^{(i)}=j
y(i)=j时为一,那么下式为半正定矩阵,因而对于softmax而言,交叉熵为凸函数:
∇
w
n
2
C
=
−
1
m
∑
i
=
1
m
∇
w
n
[
x
(
i
)
(
1
{
y
(
i
)
=
n
}
−
a
n
)
]
=
1
m
∑
i
=
1
m
a
n
(
1
−
a
n
)
x
(
i
)
x
(
i
)
T
\nabla_{w_n}^2C=-\frac1m\sum_{i=1}^m\nabla_{w_n}\left[x^{(i)}(1\{y^{(i)}=n\}-a_n)\right]=\frac1m\sum_{i=1}^ma_n(1-a_n)x^{(i)}x^{(i)T}
∇wn2C=−m1i=1∑m∇wn[x(i)(1{y(i)=n}−an)]=m1i=1∑man(1−an)x(i)x(i)T
L ^ ( A + α Δ A ) \hat{\mathcal{L}}(A+ \alpha \Delta A) L^(A+αΔA)的一个上界的证明
注意参数
A
A
A的更新方式采取最简单的SGD:
A
t
+
1
←
A
t
+
α
∇
A
L
^
A^{t+1} \leftarrow A^t+\alpha \nabla_A \hat{\mathcal{L}}
At+1←At+α∇AL^
证明:当
∇
A
2
L
^
≤
M
I
\nabla^2_A \hat{\mathcal{L}} \leq MI
∇A2L^≤MI时:
L
^
(
A
+
α
Δ
A
)
≤
L
^
(
A
)
+
γ
∣
∣
∇
A
L
^
∣
∣
2
\hat{\mathcal{L}}(A+ \alpha \Delta A) \leq \hat{\mathcal{L}}(A) + \gamma||\nabla_A \hat{\mathcal{L}}||^2
L^(A+αΔA)≤L^(A)+γ∣∣∇AL^∣∣2
proof:
首先易知
−
∇
A
L
^
(
A
)
=
Δ
A
-\nabla_A \hat{\mathcal{L}}(A) = \Delta A
−∇AL^(A)=ΔA;现在我们对
L
^
(
A
+
α
Δ
A
)
\hat{\mathcal{L}}(A+ \alpha \Delta A)
L^(A+αΔA)作Taylor展开:
L
^
(
A
+
α
Δ
A
)
=
L
^
(
A
)
+
α
∇
A
L
^
(
A
)
⊙
Δ
A
+
∇
A
2
L
^
∣
∣
Δ
A
∣
∣
2
α
2
/
2
\hat{\mathcal{L}}(A+ \alpha \Delta A) = \hat{\mathcal{L}}(A)+ \alpha \nabla_A \hat{\mathcal{L}}(A) \odot \Delta A +\nabla^2_A \hat{\mathcal{L}} ||\Delta A||^2 \alpha^2 /2
L^(A+αΔA)=L^(A)+α∇AL^(A)⊙ΔA+∇A2L^∣∣ΔA∣∣2α2/2
≤
L
^
(
A
)
+
α
∇
A
L
^
⊙
(
−
∇
A
L
^
)
+
M
∣
∣
Δ
A
∣
∣
2
α
2
/
2
\le \hat{\mathcal{L}}(A)+ \alpha \nabla_A \hat{\mathcal{L}} \odot (-\nabla_A \hat{\mathcal{L}}) +M||\Delta A||^2 \alpha^2 /2
≤L^(A)+α∇AL^⊙(−∇AL^)+M∣∣ΔA∣∣2α2/2
=
L
^
(
A
)
+
(
α
2
M
/
2
−
α
)
∣
∣
∇
A
L
^
∣
∣
2
= \hat{\mathcal{L}}(A)+(\alpha^2M /2 - \alpha )||\nabla_A \hat{\mathcal{L}}||^2
=L^(A)+(α2M/2−α)∣∣∇AL^∣∣2
现在令
γ
=
α
2
M
/
2
−
α
≤
0
\gamma = \alpha^2M /2 - \alpha \le 0
γ=α2M/2−α≤0即可满足:
L
^
(
A
+
α
Δ
A
)
≤
L
^
(
A
+
α
Δ
A
)
+
(
α
−
α
2
M
/
2
)
∣
∣
∇
A
L
^
∣
∣
2
≤
L
^
(
A
)
\hat{\mathcal{L}}(A+ \alpha \Delta A) \le \hat{\mathcal{L}}(A+ \alpha \Delta A) +(\alpha-\alpha^2M /2 )||\nabla_A \hat{\mathcal{L}}||^2 \le \hat{\mathcal{L}}(A)
L^(A+αΔA)≤L^(A+αΔA)+(α−α2M/2)∣∣∇AL^∣∣2≤L^(A)
得证 □ \Box □;
最终的正则项
现在我们已知,只需要 ∇ A 2 L ^ ≤ M I \nabla^2_A \hat{\mathcal{L}} \leq MI ∇A2L^≤MI,就可以使得每一步梯度更新保证减小目标函数;我们的正则项是 λ M ∣ ∣ I ∣ ∣ 2 ∇ A 2 L ^ \lambda \frac{M||I||^2}{\nabla^2_A \hat{\mathcal{L}}} λ∇A2L^M∣∣I∣∣2,最终的损失是:
L ( A s ) = − y ^ ⊙ l o g ( f s o f t m a x ( A s ⊙ C ∗ T ) ) + λ M ∣ ∣ I ∣ ∣ 2 ∇ A s 2 L ( A s ) ^ \mathcal{L}(A_s) = -\hat{y} \odot log(f_{softmax}(A_s \odot C^{*T}) )+\lambda \frac{M||I||^2}{\nabla^2_{A_s} \hat{\mathcal{L(A_s)}}} L(As)=−y^⊙log(fsoftmax(As⊙C∗T))+λ∇As2L(As)^M∣∣I∣∣2
还没有验证实验效果,本周马上会使用这个损失来训练概念学习模型。