Reparameterization Trick

之前在学蒸馏的时候接触了gumbel-softmax,顺势了解了一下重参数技巧,还是很有意思的一个东西

引入

重参数技巧主要是尝试对这样形式的一个东西求梯度
L θ = E z ∼ p θ ( z ) [ f θ ( z ) ] ( 1 ) \large L_{\theta} = E_{z\sim p_{\theta}(z)}[f_{\theta}(z)] \quad \quad(1) Lθ=Ezpθ(z)[fθ(z)](1)
其中 z ∼ p θ ( z ) z\sim p_{\theta}(z) zpθ(z)表示随机变量 z z z服从概率密度函数 p θ ( z ) p_{\theta}(z) pθ(z),显然这个密度函数是跟模型参数 θ \theta θ有关的 f θ ( z ) f_{\theta}(z) fθ(z)一般可以表示模型某一层关于变量 z z z的输出,显然它也跟模型参数 θ \theta θ有关

不妨先来想想这个式子要如何处理。一个非常naive的思路:采样估计。但是如果直接采样的话,每次采样我们只能获得 ∇ θ f θ ( z ) \nabla_\theta f_\theta(z) θfθ(z),而不同样本之间的信息是无法共用的,我们也就无从得到 ∇ θ L θ \nabla_\theta L_\theta θLθ。所以我们想想看,有没有什么好的处理方法,能在估计出 ( 1 ) (1) (1)式的同时还能保留梯度信息

不妨先来做一个简化,我们先假设 p θ ( z ) p_{\theta}(z) pθ(z)是一个跟 θ \theta θ无关的概率密度函数,简记为 p ( z ) p(z) p(z),我们很快注意到现在是可以采样估计梯度了:
∇ θ L θ = ∇ θ E z ∼ p ( z ) [ f θ ( z ) ] = ∇ θ [ ∫ z p ( z ) f θ ( z ) d z ] = ∫ z p ( z ) ∇ θ f θ ( z ) d z = E z ∼ p ( z ) [ ∇ θ f θ ( z ) ] \large \nabla_{\theta} L_{\theta} = \nabla_{\theta}E_{z\sim p(z)}[f_{\theta}(z)] = \nabla_{\theta}[\int_z p(z)f_{\theta}(z)dz]\\ =\int_z p(z)\nabla_{\theta}f_{\theta}(z)dz\\ =E_{z\sim p(z)}[\nabla_{\theta}f_{\theta}(z)] θLθ=θEzp(z)[fθ(z)]=θ[zp(z)fθ(z)dz]=zp(z)θfθ(z)dz=Ezp(z)[θfθ(z)]
从而
∇ θ L θ ≈ 1 n ∑ i = 1 n ∇ θ f θ ( z i ) , z i ∼ p ( z ) \large \nabla_{\theta} L_{\theta} \approx \frac{1}{n}\sum_{i=1}^{n} \nabla_{\theta}f_{\theta}(z_i),z_i\sim p(z) θLθn1i=1nθfθ(zi),zip(z)
这是因为求梯度的操作成功转移到了 f θ ( z ) f_\theta(z) fθ(z)上面
上述过程可以用一句话来总结:期望的梯度等于梯度的期望

那我们回到 p θ ( z ) p_\theta(z) pθ(z),并尝试类似的步骤:
∇ θ L θ = ∇ θ E z ∼ p θ ( z ) [ f θ ( z ) ] = ∇ θ [ ∫ z p θ ( z ) f θ ( z ) d z ] = ∫ z p θ ( z ) ∇ θ f θ ( z ) d z + ∫ z ∇ θ p θ ( z ) f θ ( z ) d z = E z ∼ p θ ( z ) [ ∇ θ f θ ( z ) ] + ∫ z ∇ θ p θ ( z ) f θ ( z ) d z ⏟ ? ? ? \large \nabla_{\theta} L_{\theta} = \nabla_{\theta}E_{z\sim p_\theta(z)}[f_{\theta}(z)] = \nabla_{\theta}[\int_z p_\theta(z)f_{\theta}(z)dz]\\ =\int_z p_\theta(z)\nabla_{\theta}f_{\theta}(z)dz+\int_z \nabla_{\theta}p_\theta(z)f_{\theta}(z)dz\\ =E_{z\sim p_\theta(z)}[\nabla_{\theta}f_{\theta}(z)]+\underbrace{\int_z \nabla_{\theta}p_\theta(z)f_{\theta}(z)dz}_{???} θLθ=θEzpθ(z)[fθ(z)]=θ[zpθ(z)fθ(z)dz]=zpθ(z)θfθ(z)dz+zθpθ(z)fθ(z)dz=Ezpθ(z)[θfθ(z)]+??? zθpθ(z)fθ(z)dz
前面一块还是可以仿照之前的处理的,但是后者就显得比较诡异了,求梯度操作转移到 p θ ( z ) p_\theta(z) pθ(z)上面去,也就意味着我们无法将其整理成正常的关于某个东西的期望的形式。或许我们可以将 ∇ θ p θ ( z ) \nabla_{\theta}p_\theta(z) θpθ(z)求出来,但在大部分情况下这是不现实的。

此时就可以引入重参数技巧了

重参数

顾名思义,我们需要引入新的参数来处理上述问题:
考虑一个新的无参数分布
ϵ ∼ q ( ϵ ) \large \epsilon\sim{q(\epsilon)} ϵq(ϵ)
以及变换
z = g θ ( ϵ ) \large z = g_\theta(\epsilon) z=gθ(ϵ)
保证变换之后得到的 z z z服从 p θ p_\theta pθ
那么对 ( 1 ) (1) (1)式求梯度可以变成:
∇ θ L θ = ∇ θ E z ∼ p θ ( z ) [ f θ ( z ) ] = E ϵ ∼ q ( ϵ ) [ f θ ( g θ ( ϵ ) ) ] ( a ) = E ϵ ∼ q ( ϵ ) [ ∇ θ f θ ( g θ ( ϵ ) ) ]     ( b ) \large \nabla_{\theta} L_{\theta} = \nabla_{\theta}E_{z\sim p_\theta(z)}[f_{\theta}(z)] \\ = E_{\epsilon\sim q(\epsilon)}[f_\theta(g_\theta(\epsilon))]\quad \quad (a)\\ =E_{\epsilon\sim q(\epsilon)}[\nabla_{\theta}f_\theta(g_\theta(\epsilon))]\ \ \ (b) θLθ=θEzpθ(z)[fθ(z)]=Eϵq(ϵ)[fθ(gθ(ϵ))](a)=Eϵq(ϵ)[θfθ(gθ(ϵ))]   (b)
从而
∇ θ L θ ≈ 1 n ∑ i = 1 n ∇ θ f θ ( g θ ( ϵ i ) ) , ϵ i ∼ q ( ϵ ) \large \nabla_{\theta} L_{\theta} \approx \frac{1}{n}\sum_{i=1}^{n} \nabla_{\theta}f_\theta(g_\theta(\epsilon_i)),\epsilon_i\sim q(\epsilon) θLθn1i=1nθfθ(gθ(ϵi)),ϵiq(ϵ)

我们就成功实现了在采样的同时保持了梯度

注意,在这个过程中最重要的一步转化就是:
L θ = E ϵ ∼ q ( ϵ ) [ f θ ( g θ ( ϵ ) ) ] \large L_\theta = E_{\epsilon\sim q(\epsilon)}[f_\theta(g_\theta(\epsilon))] Lθ=Eϵq(ϵ)[fθ(gθ(ϵ))]
它将随机性从参数 θ \theta θ转移到了内部无参数的 ϵ \epsilon ϵ上面,从而可以利用我们之前讨论过的对无参数分布(或者说无可变参数)而言成立的“期望的梯度等于梯度的期望”这一性质来处理

例子

不妨就取 p θ ( z ) p_\theta(z) pθ(z)是一个正态分布,即
p θ ( z ) = N ( μ θ , σ θ 2 ) \large p_\theta(z) = N(\mu_\theta,\sigma_\theta^2) pθ(z)=N(μθ,σθ2)
那么 q ( ϵ ) q(\epsilon) q(ϵ)我们就取标准正态分布
q ( ϵ ) = N ( 0 , 1 ) \large q(\epsilon) = N(0,1) q(ϵ)=N(0,1)
那么显然有
σ θ ϵ + μ θ ∼ N ( μ θ , σ θ 2 ) \large \sigma_\theta\epsilon+\mu_\theta \sim N(\mu_\theta,\sigma_\theta^2) σθϵ+μθN(μθ,σθ2)
所以我们就取
g θ ( ϵ ) = σ θ ϵ + μ θ \large g_\theta(\epsilon) = \sigma_\theta\epsilon+\mu_\theta gθ(ϵ)=σθϵ+μθ
最后有
E z ∼ N ( μ θ , σ θ 2 ) [ f θ ( z ) ] = E ϵ ∼ N ( 0 , 1 ) [ f θ ( σ θ ϵ + μ θ ) ] \large E_{z\sim N(\mu_\theta,\sigma_\theta^2)}[f_{\theta}(z)] = E_{\epsilon\sim N(0,1)}[f_\theta(\sigma_\theta\epsilon+\mu_\theta)] EzN(μθ,σθ2)[fθ(z)]=EϵN(0,1)[fθ(σθϵ+μθ)]

离散情况的重参数处理

上述过程处理的是分布为连续密度函数的情况,但我们也经常遇到离散分布的情况,这种该如何处理?
为做区分,我们换一种写法:
L θ = E y ∼ p θ ( y ) [ f θ ( y ) ] = ∑ y p θ ( y ) f θ ( y ) ( 2 ) \large L_{\theta} = E_{y\sim p_{\theta}(y)}[f_{\theta}(y)] = \sum_{y}p_\theta(y)f_\theta(y) \quad \quad (2) Lθ=Eypθ(y)[fθ(y)]=ypθ(y)fθ(y)(2)
一般来说,此时 y y y是可枚举的,它在大部分情况下都对应了一个k分类问题,也就是说, y y y可以表示为
p θ ( y ) = s o f t m a x ( o 1 , o 2 , . . . o k ) y = 1 ∑ e o i e o y ( 3 ) \large p_\theta(y) = softmax(o_1,o_2,...o_k)_y = \frac{1}{\sum e^{o_i}}e^{o_y}\quad(3) pθ(y)=softmax(o1,o2,...ok)y=eoi1eoy(3)
其中 o i o_i oi一般就是模型的logits,它当然也是关于参数 θ \theta θ的函数

还是同一个问题, ( 2 ) (2) (2)式直接用求和的形式是没法计算梯度的,我们还是得试试重参数方法。

所以现在问题就变成了:

找到一个合适的无参数分布 q ( ϵ ) q(\epsilon) q(ϵ)以及对应的变换 g θ ( ϵ ) g_\theta(\epsilon) gθ(ϵ)保证它服从 p θ p_\theta pθ这个分布

事实上也确实已经有对应的成果了,它叫做

Gumbel Max


ϵ ∼ U ( 0 , 1 ) \large \epsilon\sim U(0,1) ϵU(0,1)
对应的 q θ ( ϵ ) q_{\theta}(\epsilon) qθ(ϵ)为:
a r g m a x i ( l o g p i − l o g ( − l o g ϵ i ) ) i = 1 k ( 4 ) \large argmax_i(log p_i-log(-log \epsilon_i))_{i=1}^{k}\quad \quad (4) argmaxi(logpilog(logϵi))i=1k(4)
这里第 p θ ( i ) p_{\theta}(i) pθ(i)简记为 p i p_i pi
我们只需证明 ( 3 ) (3) (3)式与 ( 4 ) (4) (4)式是同一个分布,即 ( 4 ) (4) (4)式输出数字 i i i的概率为 p i p_i pi

不失一般性地,我们考虑 ( 4 ) (4) (4)式输出数字1的概率:
此时意味着 l o g p 1 − l o g ( − l o g ϵ 1 ) log p_1-log(-log \epsilon_1) logp1log(logϵ1) 1 − k 1-k 1k中最大的,即
l o g p 1 − l o g ( − l o g ϵ 1 ) ≥ l o g p i − l o g ( − l o g ϵ i ) , ∀ i ∈ ( 1 , k ] log p_1-log(-log \epsilon_1)\geq log p_i-log(-log \epsilon_i) ,\forall i\in (1,k] logp1log(logϵ1)logpilog(logϵi),i(1,k]
得到
ϵ i ≤ ϵ 1 p i / p 1 ≤ 1 , ∀ i ∈ ( 1 , k ] \large \epsilon_i\leq \epsilon_1^{p_i/p_1}\leq 1,\forall i\in (1,k] ϵiϵ1pi/p11,i(1,k]
e i ∼ U ( 0 , 1 ) e_i\sim U(0,1) eiU(0,1),从而
P ( ϵ i ≤ ϵ 1 p i / p 1 ) = ϵ 1 p i / p 1 , ∀ i ∈ ( 1 , k ] \large P(\epsilon_i\leq \epsilon_1^{p_i/p_1})=\epsilon_1^{p_i/p_1},\forall i\in (1,k] P(ϵiϵ1pi/p1)=ϵ1pi/p1,i(1,k]
从而 ( 4 ) (4) (4)式输出1的概率为
P ( ϵ 2 ≤ ϵ 1 p 2 / p 1 , ϵ 3 ≤ ϵ 1 p 3 / p 1 , . . . ϵ k ≤ ϵ 1 p k / p 1 ) = ∏ i = 2 k ϵ 1 p i / p 1 = ϵ 1 ( 1 − p 1 ) / p 1 \large P(\epsilon_2\leq \epsilon_1^{p_2/p_1},\epsilon_3\leq \epsilon_1^{p_3/p_1},...\epsilon_k\leq \epsilon_1^{p_k/p_1}) = \prod_{i=2}^{k}\epsilon_1^{p_i/p_1}=\epsilon_1^{(1-p_1)/p_1} P(ϵ2ϵ1p2/p1,ϵ3ϵ1p3/p1,...ϵkϵ1pk/p1)=i=2kϵ1pi/p1=ϵ1(1p1)/p1
ϵ 1 \epsilon_1 ϵ1的所有情况求个平均,得到
∫ 0 1 ϵ 1 ( 1 − p 1 ) / p 1 d ϵ 1 = p 1 \large \int_0^1 \epsilon_1^{(1-p_1)/p_1}d\epsilon_1 = p_1 01ϵ1(1p1)/p1dϵ1=p1
这就是 ( 4 ) (4) (4)式输出1的概率,它恰好为 p 1 p_1 p1
从而我们证明了 ( 4 ) (4) (4)式与 ( 3 ) (3) (3)式确实是同分布,所以我们就成功找到了合理的无参数分布 q ( ϵ ) q(\epsilon) q(ϵ)以及对应的变换 g θ ( ϵ ) g_\theta(\epsilon) gθ(ϵ) □ \square
那么所有过程似乎到这里就圆满结束了。

但是!但是,这里还是有点问题:argmax这个运算本身也是无法求导的…
也就是说,我们将求梯度运算转移到了 a r g m a x argmax argmax运算上面,结果它还是没有办法求梯度?
不过没关系,这一步其实并不是很难处理。我们知道 a r g m a x argmax argmax其实可以扩展成 o n e _ h o t ( a r g m a x ) one\_hot(argmax) one_hot(argmax),而后者的一个光滑近似就是 s o f t m a x softmax softmax:对于这一点,我相信接触过蒸馏的同学肯定是很清楚的,我们只需要调整蒸馏的温度就能使得 s o f t m a x softmax softmax无限趋近于 o n t _ h o t ont\_hot ont_hot
s o f t m a x softmax softmax显然是可以求梯度的,我们就顺利解决了这个遗留的问题。
这种策略被称为

Gumbel Softmax

具体来说,我们的 g θ ( ϵ ) g_\theta(\epsilon) gθ(ϵ)要改成:
s o f t m a x i ( ( l o g p i − l o g ( − l o g ϵ i ) ) / τ ) i = 1 k ( 5 ) \large softmax_i((log p_i-log(-log \epsilon_i))/\tau)_{i=1}^{k}\quad \quad (5) softmaxi((logpilog(logϵi))/τ)i=1k(5)
其中 τ \tau τ就是蒸馏的温度,当 τ → 0 \tau\rightarrow 0 τ0的时候, s o f t m a x softmax softmax就可以看成 o n t _ h o t ont\_hot ont_hot,当然此时梯度消失现象也会很严重。
由此我们也可以得到训练策略:对参数 τ \tau τ进行退火,最后得到接近于 o n t _ h o t ont\_hot ont_hot形式对应的结果。常见的一个退火策略为:
τ p = τ 0 ( τ p / τ 0 ) p / P \large \tau_p = \tau_0(\tau_p/\tau_0)^{p/P} τp=τ0(τp/τ0)p/P
其中 τ p \tau_p τp是第 p p p次训练的温度, τ 0 \tau_0 τ0是初始温度, P P P是总轮数。


总结一下,对于总体的 k k k个情况,我们从0到1的均匀分布中取 k k k个值,利用Gumbel softmax得到一个 k k k维向量 p ~ \tilde{p} p~,
那么
∑ y p ~ y f θ ( y ) \sum_y \tilde{p}_yf_\theta(y) yp~yfθ(y)
就是 L θ L_\theta Lθ的一个良好估计,并且它成功保留了梯度信息

需要指出的是,Gumbel Max是原式的等价形式,但是Gumbel Softmax并不是,它是Gumbel Max的一个光滑近似,当 τ \tau τ足够小的时候,它可以近似看成Gumbel Max

顺便提一嘴这个东西为啥叫Gumbel Max/Softmax:
我们仔细观察 ( 5 ) (5) (5)式:
s o f t m a x i ( ( l o g p i − l o g ( − l o g ϵ i ) ) / τ ) i = 1 k \large softmax_i((log p_i-log(-log \epsilon_i))/\tau)_{i=1}^{k} softmaxi((logpilog(logϵi))/τ)i=1k
按照原本的思路,我们可以先从均匀分布里采样 ϵ \epsilon ϵ,然后再做log运算,再做log运算,再与 l o g p i logp_i logpi做差,不过实际上实际从一个 − l o g ( − l o g ϵ ) -log(-log \epsilon) log(logϵ)服从的分布里直接采样也是完全OK的,那我们就来看看这个分布长什么样子:

x = − l o g ( − l o g ϵ ) x = -log(-log \epsilon) x=log(logϵ)
那么
F X ( x ) = P X ( X ≤ x ) = P ϵ ( − l o g ( − l o g ϵ ) ≤ x ) = P ϵ ( ϵ ≤ e − e − x ) = F ϵ ( e − e − x ) F_X(x) = P_X(X\leq x) = P_\epsilon(-log(-log \epsilon)\leq x) = P_\epsilon(\epsilon\leq e^{-e^{-x}}) = F_\epsilon(e^{-e^{-x}}) FX(x)=PX(Xx)=Pϵ(log(logϵ)x)=Pϵ(ϵeex)=Fϵ(eex)
从而
F X ( x ) = e x p ( − e x p ( − x ) ) F_X(x) = exp(-exp(-x)) FX(x)=exp(exp(x))
这就是这个分布的累积分布函数,它就被称为Gumbel分布。实际上Gumbel分布还带有另外两个参数
F X ( x , μ , β ) = e x p ( − e x p ( − x − μ β ) ) F_X(x,\mu,\beta) = exp(-exp(-\frac{x-\mu}{\beta})) FX(x,μ,β)=exp(exp(βxμ))
也就是说这里是 μ = β = 0 \mu=\beta=0 μ=β=0的特殊情况。不过这一点不必细讲,感兴趣的读者可以再去了解一下。

最后讲一个实现细节:
在求原分布 q θ q_\theta qθ的时候,我们需要从 { o i } \{o_i\} {oi}出发做softmax得到 { p i } \{p_i\} {pi},但是实际上 ( 5 ) (5) (5)式可以直接替换为
s o f t m a x i ( ( o i − l o g ( − l o g ϵ i ) ) / τ ) i = 1 k \large softmax_i((o_i-log(-log \epsilon_i))/\tau)_{i=1}^{k} softmaxi((oilog(logϵi))/τ)i=1k
那么我们就不必去做softmax了
至于证明其实也很简单:
l o g p i = l o g ( s o f t m a x ( o i ) ) = l o g ( e o i ∑ j e o j ) \large log p_i = log(softmax(o_i)) = log(\frac{e^{o_i}}{\sum_j e^{o_j}}) logpi=log(softmax(oi))=log(jeojeoi)
从而
l o g p i = o i − C logp_i = o_i-C logpi=oiC
从而
s o f t m a x ( ( l o g p i + g i ) / τ ) = e ( l o g p i + g i ) / τ ∑ j e ( l o g p j + g j ) / τ = e ( o i − C + g i ) / τ ∑ j e ( o j − C + g j ) / τ softmax((logp_i+g_i)/\tau) = \frac{e^{(logp_i+g_i)/\tau}}{\sum_j e^{(logp_j+g_j)/\tau}} = \frac{e^{(o_i-C+g_i)/\tau}}{\sum_j e^{(o_j-C+g_j)/\tau}} softmax((logpi+gi)/τ)=je(logpj+gj)/τe(logpi+gi)/τ=je(ojC+gj)/τe(oiC+gi)/τ
显然可以将常数 C C C对应的部分提出来
= e ( o i + g i ) / τ ∑ j e ( o j + g j ) / τ = s o f t m a x ( ( o i + g i ) / τ ) = \frac{e^{(o_i+g_i)/\tau}}{\sum_j e^{(o_j+g_j)/\tau}} = softmax((o_i+g_i)/\tau) =je(oj+gj)/τe(oi+gi)/τ=softmax((oi+gi)/τ)
这里 g i g_i gi就指之前讲的Gumbel分布

总结

以上就是重参数在连续和离散两个场景的应用了,它最初也是最多的应用应该是在VAE里面,我以后应该也会接触,到时候也许会对这篇文章加以补充。

  • 16
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值