1. 什么是对数梯度技巧?
对数梯度技巧的核心目标是计算目标函数的梯度,形式通常为期望:
J(θ)=Ep(x∣θ)[f(x)]=∫p(x∣θ)f(x) dx J(\theta) = \mathbb{E}_{p(x|\theta)}[f(x)] = \int p(x|\theta) f(x) \, dx J(θ)=Ep(x∣θ)[f(x)]=∫p(x∣θ)f(x)dx
其中:
- p(x∣θ)p(x|\theta)p(x∣θ) 是参数化的概率分布,依赖于参数 θ\thetaθ(例如神经网络权重或分布参数)。
- f(x)f(x)f(x) 是基于样本 xxx 的函数,例如损失、奖励或目标值。
- 我们需要计算 ∇θJ(θ)\nabla_\theta J(\theta)∇θJ(θ) 以通过梯度下降优化 θ\thetaθ。
直接计算梯度 ∇θJ(θ)\nabla_\theta J(\theta)∇θJ(θ) 涉及对概率分布 p(x∣θ)p(x|\theta)p(x∣θ) 求导,可能面临积分复杂、数值不稳定等问题。对数梯度技巧通过以下关键性质简化计算:
∇θlogp(x∣θ)=∇θp(x∣θ)p(x∣θ) \nabla_\theta \log p(x|\theta) = \frac{\nabla_\theta p(x|\theta)}{p(x|\theta)} ∇θlogp(x∣θ)=p(x∣θ)∇θp(x∣θ)
这允许我们将梯度表达为一个可通过蒙特卡洛采样估计的期望形式:
∇θJ(θ)=Ep(x∣θ)[∇θlogp(x∣θ)f(x)] \nabla_\theta J(\theta) = \mathbb{E}_{p(x|\theta)}[\nabla_\theta \log p(x|\theta) f(x)] ∇θJ(θ)=Ep(x∣θ)[∇θlogp(x∣θ)f(x)]
这个公式是对数梯度技巧的核心,广泛应用于需要优化概率分布的场景。
2. 为什么需要对数梯度技巧?
在优化概率模型时,目标函数通常涉及期望,形式如:
J(θ)=Ep(x∣θ)[f(x)] J(\theta) = \mathbb{E}_{p(x|\theta)}[f(x)] J(θ)=Ep(x∣θ)[f(x)]
直接计算梯度:
∇θJ(θ)=∇θ∫p(x∣θ)f(x) dx \nabla_\theta J(\theta) = \nabla_\theta \int p(x|\theta) f(x) \, dx ∇θJ(θ)=∇θ∫p(x∣θ)f(x)dx
会遇到以下挑战:
- 复杂积分:p(x∣θ)p(x|\theta)p(x∣θ) 可能是高维或非解析的,积分难以直接求解。
- 参数依赖:p(x∣θ)p(x|\theta)p(x∣θ) 依赖于 θ\thetaθ,梯度需要作用在整个积分上,计算复杂。
- 数值不稳定:直接操作概率密度可能导致数值溢出或下溢,尤其在概率值极小或极大时。
- 黑盒场景:在某些情况下(如商业 API),模型是黑盒的,无法访问内部参数或梯度。
对数梯度技巧通过对数导数性质将梯度转化为期望形式,允许通过采样估计,避免直接求解复杂积分,且对黑盒模型友好。
3. 对数梯度技巧的数学推导
以下是对数梯度技巧的详细数学推导,涵盖连续和离散分布。
3.1 目标函数
假设目标函数为:
J(θ)=Ep(x∣θ)[f(x)]=∫p(x∣θ)f(x) dx J(\theta) = \mathbb{E}_{p(x|\theta)}[f(x)] = \int p(x|\theta) f(x) \, dx J(θ)=Ep(x∣θ)[f(x)]=∫p(x∣θ)f(x)dx
我们需要计算梯度 ∇θJ(θ)\nabla_\theta J(\theta)∇θJ(θ)。对于离散分布,积分替换为求和,但推导原理类似。
3.2 对数导数性质
考虑概率密度 p(x∣θ)p(x|\theta)p(x∣θ) 的对数:
logp(x∣θ) \log p(x|\theta) logp(x∣θ)
对其求梯度:
∇θlogp(x∣θ)=∇θp(x∣θ)p(x∣θ) \nabla_\theta \log p(x|\theta) = \frac{\nabla_\theta p(x|\theta)}{p(x|\theta)} ∇θlogp(x∣θ)=p(x∣θ)∇θp(x∣θ)
因此:
∇θp(x∣θ)=p(x∣θ)∇θlogp(x∣θ) \nabla_\theta p(x|\theta) = p(x|\theta) \nabla_\theta \log p(x|\theta) ∇θp(x∣θ)=p(x∣θ)∇θlogp(x∣θ)
3.3 梯度推导
将目标函数的梯度展开:
∇θJ(θ)=∇θ∫p(x∣θ)f(x) dx \nabla_\theta J(\theta) = \nabla_\theta \int p(x|\theta) f(x) \, dx ∇θJ(θ)=∇</