之前在学蒸馏的时候接触了gumbel-softmax,顺势了解了一下重参数技巧,还是很有意思的一个东西
引入
重参数技巧主要是尝试对这样形式的一个东西求梯度
L
θ
=
E
z
∼
p
θ
(
z
)
[
f
θ
(
z
)
]
(
1
)
\large L_{\theta} = E_{z\sim p_{\theta}(z)}[f_{\theta}(z)] \quad \quad(1)
Lθ=Ez∼pθ(z)[fθ(z)](1)
其中
z
∼
p
θ
(
z
)
z\sim p_{\theta}(z)
z∼pθ(z)表示随机变量
z
z
z服从概率密度函数
p
θ
(
z
)
p_{\theta}(z)
pθ(z),显然这个密度函数是跟模型参数
θ
\theta
θ有关的;
f
θ
(
z
)
f_{\theta}(z)
fθ(z)一般可以表示模型某一层关于变量
z
z
z的输出,显然它也跟模型参数
θ
\theta
θ有关
不妨先来想想这个式子要如何处理。一个非常naive的思路:采样估计。但是如果直接采样的话,每次采样我们只能获得 ∇ θ f θ ( z ) \nabla_\theta f_\theta(z) ∇θfθ(z),而不同样本之间的信息是无法共用的,我们也就无从得到 ∇ θ L θ \nabla_\theta L_\theta ∇θLθ。所以我们想想看,有没有什么好的处理方法,能在估计出 ( 1 ) (1) (1)式的同时还能保留梯度信息
不妨先来做一个简化,我们先假设
p
θ
(
z
)
p_{\theta}(z)
pθ(z)是一个跟
θ
\theta
θ无关的概率密度函数,简记为
p
(
z
)
p(z)
p(z),我们很快注意到现在是可以采样估计梯度了:
∇
θ
L
θ
=
∇
θ
E
z
∼
p
(
z
)
[
f
θ
(
z
)
]
=
∇
θ
[
∫
z
p
(
z
)
f
θ
(
z
)
d
z
]
=
∫
z
p
(
z
)
∇
θ
f
θ
(
z
)
d
z
=
E
z
∼
p
(
z
)
[
∇
θ
f
θ
(
z
)
]
\large \nabla_{\theta} L_{\theta} = \nabla_{\theta}E_{z\sim p(z)}[f_{\theta}(z)] = \nabla_{\theta}[\int_z p(z)f_{\theta}(z)dz]\\ =\int_z p(z)\nabla_{\theta}f_{\theta}(z)dz\\ =E_{z\sim p(z)}[\nabla_{\theta}f_{\theta}(z)]
∇θLθ=∇θEz∼p(z)[fθ(z)]=∇θ[∫zp(z)fθ(z)dz]=∫zp(z)∇θfθ(z)dz=Ez∼p(z)[∇θfθ(z)]
从而
∇
θ
L
θ
≈
1
n
∑
i
=
1
n
∇
θ
f
θ
(
z
i
)
,
z
i
∼
p
(
z
)
\large \nabla_{\theta} L_{\theta} \approx \frac{1}{n}\sum_{i=1}^{n} \nabla_{\theta}f_{\theta}(z_i),z_i\sim p(z)
∇θLθ≈n1i=1∑n∇θfθ(zi),zi∼p(z)
这是因为求梯度的操作成功转移到了
f
θ
(
z
)
f_\theta(z)
fθ(z)上面
上述过程可以用一句话来总结:期望的梯度等于梯度的期望
那我们回到
p
θ
(
z
)
p_\theta(z)
pθ(z),并尝试类似的步骤:
∇
θ
L
θ
=
∇
θ
E
z
∼
p
θ
(
z
)
[
f
θ
(
z
)
]
=
∇
θ
[
∫
z
p
θ
(
z
)
f
θ
(
z
)
d
z
]
=
∫
z
p
θ
(
z
)
∇
θ
f
θ
(
z
)
d
z
+
∫
z
∇
θ
p
θ
(
z
)
f
θ
(
z
)
d
z
=
E
z
∼
p
θ
(
z
)
[
∇
θ
f
θ
(
z
)
]
+
∫
z
∇
θ
p
θ
(
z
)
f
θ
(
z
)
d
z
⏟
?
?
?
\large \nabla_{\theta} L_{\theta} = \nabla_{\theta}E_{z\sim p_\theta(z)}[f_{\theta}(z)] = \nabla_{\theta}[\int_z p_\theta(z)f_{\theta}(z)dz]\\ =\int_z p_\theta(z)\nabla_{\theta}f_{\theta}(z)dz+\int_z \nabla_{\theta}p_\theta(z)f_{\theta}(z)dz\\ =E_{z\sim p_\theta(z)}[\nabla_{\theta}f_{\theta}(z)]+\underbrace{\int_z \nabla_{\theta}p_\theta(z)f_{\theta}(z)dz}_{???}
∇θLθ=∇θEz∼pθ(z)[fθ(z)]=∇θ[∫zpθ(z)fθ(z)dz]=∫zpθ(z)∇θfθ(z)dz+∫z∇θpθ(z)fθ(z)dz=Ez∼pθ(z)[∇θfθ(z)]+???
∫z∇θpθ(z)fθ(z)dz
前面一块还是可以仿照之前的处理的,但是后者就显得比较诡异了,求梯度操作转移到
p
θ
(
z
)
p_\theta(z)
pθ(z)上面去,也就意味着我们无法将其整理成正常的关于某个东西的期望的形式。或许我们可以将
∇
θ
p
θ
(
z
)
\nabla_{\theta}p_\theta(z)
∇θpθ(z)求出来,但在大部分情况下这是不现实的。
此时就可以引入重参数技巧了
重参数
顾名思义,我们需要引入新的参数来处理上述问题:
考虑一个新的无参数分布
ϵ
∼
q
(
ϵ
)
\large \epsilon\sim{q(\epsilon)}
ϵ∼q(ϵ)
以及变换
z
=
g
θ
(
ϵ
)
\large z = g_\theta(\epsilon)
z=gθ(ϵ)
保证变换之后得到的
z
z
z服从
p
θ
p_\theta
pθ
那么对
(
1
)
(1)
(1)式求梯度可以变成:
∇
θ
L
θ
=
∇
θ
E
z
∼
p
θ
(
z
)
[
f
θ
(
z
)
]
=
E
ϵ
∼
q
(
ϵ
)
[
f
θ
(
g
θ
(
ϵ
)
)
]
(
a
)
=
E
ϵ
∼
q
(
ϵ
)
[
∇
θ
f
θ
(
g
θ
(
ϵ
)
)
]
(
b
)
\large \nabla_{\theta} L_{\theta} = \nabla_{\theta}E_{z\sim p_\theta(z)}[f_{\theta}(z)] \\ = E_{\epsilon\sim q(\epsilon)}[f_\theta(g_\theta(\epsilon))]\quad \quad (a)\\ =E_{\epsilon\sim q(\epsilon)}[\nabla_{\theta}f_\theta(g_\theta(\epsilon))]\ \ \ (b)
∇θLθ=∇θEz∼pθ(z)[fθ(z)]=Eϵ∼q(ϵ)[fθ(gθ(ϵ))](a)=Eϵ∼q(ϵ)[∇θfθ(gθ(ϵ))] (b)
从而
∇
θ
L
θ
≈
1
n
∑
i
=
1
n
∇
θ
f
θ
(
g
θ
(
ϵ
i
)
)
,
ϵ
i
∼
q
(
ϵ
)
\large \nabla_{\theta} L_{\theta} \approx \frac{1}{n}\sum_{i=1}^{n} \nabla_{\theta}f_\theta(g_\theta(\epsilon_i)),\epsilon_i\sim q(\epsilon)
∇θLθ≈n1i=1∑n∇θfθ(gθ(ϵi)),ϵi∼q(ϵ)
我们就成功实现了在采样的同时保持了梯度
注意,在这个过程中最重要的一步转化就是:
L
θ
=
E
ϵ
∼
q
(
ϵ
)
[
f
θ
(
g
θ
(
ϵ
)
)
]
\large L_\theta = E_{\epsilon\sim q(\epsilon)}[f_\theta(g_\theta(\epsilon))]
Lθ=Eϵ∼q(ϵ)[fθ(gθ(ϵ))]
它将随机性从参数
θ
\theta
θ转移到了内部无参数的
ϵ
\epsilon
ϵ上面,从而可以利用我们之前讨论过的对无参数分布(或者说无可变参数)而言成立的“期望的梯度等于梯度的期望”这一性质来处理
例子
不妨就取
p
θ
(
z
)
p_\theta(z)
pθ(z)是一个正态分布,即
p
θ
(
z
)
=
N
(
μ
θ
,
σ
θ
2
)
\large p_\theta(z) = N(\mu_\theta,\sigma_\theta^2)
pθ(z)=N(μθ,σθ2)
那么
q
(
ϵ
)
q(\epsilon)
q(ϵ)我们就取标准正态分布
q
(
ϵ
)
=
N
(
0
,
1
)
\large q(\epsilon) = N(0,1)
q(ϵ)=N(0,1)
那么显然有
σ
θ
ϵ
+
μ
θ
∼
N
(
μ
θ
,
σ
θ
2
)
\large \sigma_\theta\epsilon+\mu_\theta \sim N(\mu_\theta,\sigma_\theta^2)
σθϵ+μθ∼N(μθ,σθ2)
所以我们就取
g
θ
(
ϵ
)
=
σ
θ
ϵ
+
μ
θ
\large g_\theta(\epsilon) = \sigma_\theta\epsilon+\mu_\theta
gθ(ϵ)=σθϵ+μθ
最后有
E
z
∼
N
(
μ
θ
,
σ
θ
2
)
[
f
θ
(
z
)
]
=
E
ϵ
∼
N
(
0
,
1
)
[
f
θ
(
σ
θ
ϵ
+
μ
θ
)
]
\large E_{z\sim N(\mu_\theta,\sigma_\theta^2)}[f_{\theta}(z)] = E_{\epsilon\sim N(0,1)}[f_\theta(\sigma_\theta\epsilon+\mu_\theta)]
Ez∼N(μθ,σθ2)[fθ(z)]=Eϵ∼N(0,1)[fθ(σθϵ+μθ)]
离散情况的重参数处理
上述过程处理的是分布为连续密度函数的情况,但我们也经常遇到离散分布的情况,这种该如何处理?
为做区分,我们换一种写法:
L
θ
=
E
y
∼
p
θ
(
y
)
[
f
θ
(
y
)
]
=
∑
y
p
θ
(
y
)
f
θ
(
y
)
(
2
)
\large L_{\theta} = E_{y\sim p_{\theta}(y)}[f_{\theta}(y)] = \sum_{y}p_\theta(y)f_\theta(y) \quad \quad (2)
Lθ=Ey∼pθ(y)[fθ(y)]=y∑pθ(y)fθ(y)(2)
一般来说,此时
y
y
y是可枚举的,它在大部分情况下都对应了一个k分类问题,也就是说,
y
y
y可以表示为
p
θ
(
y
)
=
s
o
f
t
m
a
x
(
o
1
,
o
2
,
.
.
.
o
k
)
y
=
1
∑
e
o
i
e
o
y
(
3
)
\large p_\theta(y) = softmax(o_1,o_2,...o_k)_y = \frac{1}{\sum e^{o_i}}e^{o_y}\quad(3)
pθ(y)=softmax(o1,o2,...ok)y=∑eoi1eoy(3)
其中
o
i
o_i
oi一般就是模型的logits,它当然也是关于参数
θ
\theta
θ的函数
还是同一个问题, ( 2 ) (2) (2)式直接用求和的形式是没法计算梯度的,我们还是得试试重参数方法。
所以现在问题就变成了:
找到一个合适的无参数分布 q ( ϵ ) q(\epsilon) q(ϵ)以及对应的变换 g θ ( ϵ ) g_\theta(\epsilon) gθ(ϵ)保证它服从 p θ p_\theta pθ这个分布
事实上也确实已经有对应的成果了,它叫做
Gumbel Max
取
ϵ
∼
U
(
0
,
1
)
\large \epsilon\sim U(0,1)
ϵ∼U(0,1)
对应的
q
θ
(
ϵ
)
q_{\theta}(\epsilon)
qθ(ϵ)为:
a
r
g
m
a
x
i
(
l
o
g
p
i
−
l
o
g
(
−
l
o
g
ϵ
i
)
)
i
=
1
k
(
4
)
\large argmax_i(log p_i-log(-log \epsilon_i))_{i=1}^{k}\quad \quad (4)
argmaxi(logpi−log(−logϵi))i=1k(4)
这里第
p
θ
(
i
)
p_{\theta}(i)
pθ(i)简记为
p
i
p_i
pi了
我们只需证明
(
3
)
(3)
(3)式与
(
4
)
(4)
(4)式是同一个分布,即
(
4
)
(4)
(4)式输出数字
i
i
i的概率为
p
i
p_i
pi
不失一般性地,我们考虑
(
4
)
(4)
(4)式输出数字1的概率:
此时意味着
l
o
g
p
1
−
l
o
g
(
−
l
o
g
ϵ
1
)
log p_1-log(-log \epsilon_1)
logp1−log(−logϵ1)是
1
−
k
1-k
1−k中最大的,即
l
o
g
p
1
−
l
o
g
(
−
l
o
g
ϵ
1
)
≥
l
o
g
p
i
−
l
o
g
(
−
l
o
g
ϵ
i
)
,
∀
i
∈
(
1
,
k
]
log p_1-log(-log \epsilon_1)\geq log p_i-log(-log \epsilon_i) ,\forall i\in (1,k]
logp1−log(−logϵ1)≥logpi−log(−logϵi),∀i∈(1,k]
得到
ϵ
i
≤
ϵ
1
p
i
/
p
1
≤
1
,
∀
i
∈
(
1
,
k
]
\large \epsilon_i\leq \epsilon_1^{p_i/p_1}\leq 1,\forall i\in (1,k]
ϵi≤ϵ1pi/p1≤1,∀i∈(1,k]
又
e
i
∼
U
(
0
,
1
)
e_i\sim U(0,1)
ei∼U(0,1),从而
P
(
ϵ
i
≤
ϵ
1
p
i
/
p
1
)
=
ϵ
1
p
i
/
p
1
,
∀
i
∈
(
1
,
k
]
\large P(\epsilon_i\leq \epsilon_1^{p_i/p_1})=\epsilon_1^{p_i/p_1},\forall i\in (1,k]
P(ϵi≤ϵ1pi/p1)=ϵ1pi/p1,∀i∈(1,k]
从而
(
4
)
(4)
(4)式输出1的概率为
P
(
ϵ
2
≤
ϵ
1
p
2
/
p
1
,
ϵ
3
≤
ϵ
1
p
3
/
p
1
,
.
.
.
ϵ
k
≤
ϵ
1
p
k
/
p
1
)
=
∏
i
=
2
k
ϵ
1
p
i
/
p
1
=
ϵ
1
(
1
−
p
1
)
/
p
1
\large P(\epsilon_2\leq \epsilon_1^{p_2/p_1},\epsilon_3\leq \epsilon_1^{p_3/p_1},...\epsilon_k\leq \epsilon_1^{p_k/p_1}) = \prod_{i=2}^{k}\epsilon_1^{p_i/p_1}=\epsilon_1^{(1-p_1)/p_1}
P(ϵ2≤ϵ1p2/p1,ϵ3≤ϵ1p3/p1,...ϵk≤ϵ1pk/p1)=i=2∏kϵ1pi/p1=ϵ1(1−p1)/p1
对
ϵ
1
\epsilon_1
ϵ1的所有情况求个平均,得到
∫
0
1
ϵ
1
(
1
−
p
1
)
/
p
1
d
ϵ
1
=
p
1
\large \int_0^1 \epsilon_1^{(1-p_1)/p_1}d\epsilon_1 = p_1
∫01ϵ1(1−p1)/p1dϵ1=p1
这就是
(
4
)
(4)
(4)式输出1的概率,它恰好为
p
1
p_1
p1
从而我们证明了
(
4
)
(4)
(4)式与
(
3
)
(3)
(3)式确实是同分布,所以我们就成功找到了合理的无参数分布
q
(
ϵ
)
q(\epsilon)
q(ϵ)以及对应的变换
g
θ
(
ϵ
)
g_\theta(\epsilon)
gθ(ϵ)
□
\square
□
那么所有过程似乎到这里就圆满结束了。
但是!但是,这里还是有点问题:argmax这个运算本身也是无法求导的…
也就是说,我们将求梯度运算转移到了
a
r
g
m
a
x
argmax
argmax运算上面,结果它还是没有办法求梯度?
不过没关系,这一步其实并不是很难处理。我们知道
a
r
g
m
a
x
argmax
argmax其实可以扩展成
o
n
e
_
h
o
t
(
a
r
g
m
a
x
)
one\_hot(argmax)
one_hot(argmax),而后者的一个光滑近似就是
s
o
f
t
m
a
x
softmax
softmax:对于这一点,我相信接触过蒸馏的同学肯定是很清楚的,我们只需要调整蒸馏的温度就能使得
s
o
f
t
m
a
x
softmax
softmax无限趋近于
o
n
t
_
h
o
t
ont\_hot
ont_hot
而
s
o
f
t
m
a
x
softmax
softmax显然是可以求梯度的,我们就顺利解决了这个遗留的问题。
这种策略被称为
Gumbel Softmax
具体来说,我们的
g
θ
(
ϵ
)
g_\theta(\epsilon)
gθ(ϵ)要改成:
s
o
f
t
m
a
x
i
(
(
l
o
g
p
i
−
l
o
g
(
−
l
o
g
ϵ
i
)
)
/
τ
)
i
=
1
k
(
5
)
\large softmax_i((log p_i-log(-log \epsilon_i))/\tau)_{i=1}^{k}\quad \quad (5)
softmaxi((logpi−log(−logϵi))/τ)i=1k(5)
其中
τ
\tau
τ就是蒸馏的温度,当
τ
→
0
\tau\rightarrow 0
τ→0的时候,
s
o
f
t
m
a
x
softmax
softmax就可以看成
o
n
t
_
h
o
t
ont\_hot
ont_hot,当然此时梯度消失现象也会很严重。
由此我们也可以得到训练策略:对参数
τ
\tau
τ进行退火,最后得到接近于
o
n
t
_
h
o
t
ont\_hot
ont_hot形式对应的结果。常见的一个退火策略为:
τ
p
=
τ
0
(
τ
p
/
τ
0
)
p
/
P
\large \tau_p = \tau_0(\tau_p/\tau_0)^{p/P}
τp=τ0(τp/τ0)p/P
其中
τ
p
\tau_p
τp是第
p
p
p次训练的温度,
τ
0
\tau_0
τ0是初始温度,
P
P
P是总轮数。
总结一下,对于总体的
k
k
k个情况,我们从0到1的均匀分布中取
k
k
k个值,利用Gumbel softmax得到一个
k
k
k维向量
p
~
\tilde{p}
p~,
那么
∑
y
p
~
y
f
θ
(
y
)
\sum_y \tilde{p}_yf_\theta(y)
y∑p~yfθ(y)
就是
L
θ
L_\theta
Lθ的一个良好估计,并且它成功保留了梯度信息
需要指出的是,Gumbel Max是原式的等价形式,但是Gumbel Softmax并不是,它是Gumbel Max的一个光滑近似,当 τ \tau τ足够小的时候,它可以近似看成Gumbel Max
顺便提一嘴这个东西为啥叫Gumbel Max/Softmax:
我们仔细观察
(
5
)
(5)
(5)式:
s
o
f
t
m
a
x
i
(
(
l
o
g
p
i
−
l
o
g
(
−
l
o
g
ϵ
i
)
)
/
τ
)
i
=
1
k
\large softmax_i((log p_i-log(-log \epsilon_i))/\tau)_{i=1}^{k}
softmaxi((logpi−log(−logϵi))/τ)i=1k
按照原本的思路,我们可以先从均匀分布里采样
ϵ
\epsilon
ϵ,然后再做log运算,再做log运算,再与
l
o
g
p
i
logp_i
logpi做差,不过实际上实际从一个
−
l
o
g
(
−
l
o
g
ϵ
)
-log(-log \epsilon)
−log(−logϵ)服从的分布里直接采样也是完全OK的,那我们就来看看这个分布长什么样子:
记
x
=
−
l
o
g
(
−
l
o
g
ϵ
)
x = -log(-log \epsilon)
x=−log(−logϵ)
那么
F
X
(
x
)
=
P
X
(
X
≤
x
)
=
P
ϵ
(
−
l
o
g
(
−
l
o
g
ϵ
)
≤
x
)
=
P
ϵ
(
ϵ
≤
e
−
e
−
x
)
=
F
ϵ
(
e
−
e
−
x
)
F_X(x) = P_X(X\leq x) = P_\epsilon(-log(-log \epsilon)\leq x) = P_\epsilon(\epsilon\leq e^{-e^{-x}}) = F_\epsilon(e^{-e^{-x}})
FX(x)=PX(X≤x)=Pϵ(−log(−logϵ)≤x)=Pϵ(ϵ≤e−e−x)=Fϵ(e−e−x)
从而
F
X
(
x
)
=
e
x
p
(
−
e
x
p
(
−
x
)
)
F_X(x) = exp(-exp(-x))
FX(x)=exp(−exp(−x))
这就是这个分布的累积分布函数,它就被称为Gumbel分布。实际上Gumbel分布还带有另外两个参数
F
X
(
x
,
μ
,
β
)
=
e
x
p
(
−
e
x
p
(
−
x
−
μ
β
)
)
F_X(x,\mu,\beta) = exp(-exp(-\frac{x-\mu}{\beta}))
FX(x,μ,β)=exp(−exp(−βx−μ))
也就是说这里是
μ
=
β
=
0
\mu=\beta=0
μ=β=0的特殊情况。不过这一点不必细讲,感兴趣的读者可以再去了解一下。
最后讲一个实现细节:
在求原分布
q
θ
q_\theta
qθ的时候,我们需要从
{
o
i
}
\{o_i\}
{oi}出发做softmax得到
{
p
i
}
\{p_i\}
{pi},但是实际上
(
5
)
(5)
(5)式可以直接替换为
s
o
f
t
m
a
x
i
(
(
o
i
−
l
o
g
(
−
l
o
g
ϵ
i
)
)
/
τ
)
i
=
1
k
\large softmax_i((o_i-log(-log \epsilon_i))/\tau)_{i=1}^{k}
softmaxi((oi−log(−logϵi))/τ)i=1k
那么我们就不必去做softmax了
至于证明其实也很简单:
l
o
g
p
i
=
l
o
g
(
s
o
f
t
m
a
x
(
o
i
)
)
=
l
o
g
(
e
o
i
∑
j
e
o
j
)
\large log p_i = log(softmax(o_i)) = log(\frac{e^{o_i}}{\sum_j e^{o_j}})
logpi=log(softmax(oi))=log(∑jeojeoi)
从而
l
o
g
p
i
=
o
i
−
C
logp_i = o_i-C
logpi=oi−C
从而
s
o
f
t
m
a
x
(
(
l
o
g
p
i
+
g
i
)
/
τ
)
=
e
(
l
o
g
p
i
+
g
i
)
/
τ
∑
j
e
(
l
o
g
p
j
+
g
j
)
/
τ
=
e
(
o
i
−
C
+
g
i
)
/
τ
∑
j
e
(
o
j
−
C
+
g
j
)
/
τ
softmax((logp_i+g_i)/\tau) = \frac{e^{(logp_i+g_i)/\tau}}{\sum_j e^{(logp_j+g_j)/\tau}} = \frac{e^{(o_i-C+g_i)/\tau}}{\sum_j e^{(o_j-C+g_j)/\tau}}
softmax((logpi+gi)/τ)=∑je(logpj+gj)/τe(logpi+gi)/τ=∑je(oj−C+gj)/τe(oi−C+gi)/τ
显然可以将常数
C
C
C对应的部分提出来
=
e
(
o
i
+
g
i
)
/
τ
∑
j
e
(
o
j
+
g
j
)
/
τ
=
s
o
f
t
m
a
x
(
(
o
i
+
g
i
)
/
τ
)
= \frac{e^{(o_i+g_i)/\tau}}{\sum_j e^{(o_j+g_j)/\tau}} = softmax((o_i+g_i)/\tau)
=∑je(oj+gj)/τe(oi+gi)/τ=softmax((oi+gi)/τ)
这里
g
i
g_i
gi就指之前讲的Gumbel分布
总结
以上就是重参数在连续和离散两个场景的应用了,它最初也是最多的应用应该是在VAE里面,我以后应该也会接触,到时候也许会对这篇文章加以补充。