理解DreamerV2中似然损失与KL散度损失

假设有一组有m个样本的集合 X=\left \{ x^{(1)},...,x^{(m)} \right \},由未知的真实数据分布 p_{data}(x)生成。而我们的模型可以表述为 p_{model}(x;\theta_{real} ) , \theta为我们模型的参数。p_{model}(x;\theta_{estimate} ) 估计真实的概率分布 p_{data}(x)

1. 最大似然估计

最大似然估计(Maximum Likelihood Estimation,MLE),又叫极大似然估计,是统计学中应用最广泛的一种未知参数估计方法。 它可以在已知随机变量属于哪种概率分布的前提下, 利用随机变量的一些观测值估计出分布的一些参数值。 所谓观测值,就是随机变量的采样值,也就是这个随机变量试验的真实结果值, 因为是我们能”看到”的值,所以称为观测值。

似然函数是一种关于统计模型中的参数的函数,表示模型参数的似然性。似然函数是给定联合样本值 x 下关于(未知)参数 θ 的函数:

使模型预测的分布p_{model}(x;\theta_{estimate} )可以和数据真实分布一致,就是要确定当前模型参数\theta使模型可以解释真实分布样本集X=\left \{ x^{(1)},...,x^{(m)} \right \}问题转换成了一个已知样本集估计参数\theta的问题,参数 \theta 的最大似然估计为:

\begin{aligned} & \hat{\theta}=\arg \max _\theta p_{\text {model }}(\mathbb{X} ; \theta) \\ & =\arg \max _\theta \prod_{i=1}^m p_{\text {model }}\left(x^{(i)} ; \theta\right) \end{aligned}

为了便于计算进行对数化,

\hat{\theta}=\arg \max _\theta \sum_{i=1}^m \log \left(p_{\text {model }}\left(x^{(i)} ; \theta\right)\right)

对上述公式我们可以除以 m 求平均,m趋于无穷大时,此时该式子变为将数据经验分布 $\hat{p}_{d a t a}$ 作为概率密度,求p_{model}(x;\theta )的期望,并最大化。

$ \hat{\theta}=\arg \max _\theta \mathbb{E}_{\mathrm{x} \sim \hat{\mathrm{p}}_{\text {data }}} \log \left(p_{\text {model }}(x ; \theta)\right) $,通过负对数似然就可以将这个问题转化为最小化问题。

当对数似然函数取得了最大值,我们认为预测模型趋近于真实模型,预测分布趋近于真实分布。

2. KL散度估计

\mathrm{KL}(P \| Q)=\sum_i P(i) \log \left(\frac{P(i)}{Q(i)}\right)

最大似然的另一种观点为最小化训练集上的经验分布$\hat{p}_{d a t a}$和模型分布之间的差异,这个差异我们通常使用KL散度进行描述:

$ D_{K L}\left(\hat{p}_{\text {data }} \| p_{\text {model }}\right)=\mathbb{E}_{x \sim \hat{p}_{\text {data }}}\left[\log \left(\hat{p}_{\text {data }}(x)\right)-\log \left(p_{\text {model }}(x)\right)\right] $

$\hat{p}_{d a t a}$为数据的经验分布,和模型无关,所以要最小化散度,我们只需要最小化:

$ -\mathbb{E}_{x \sim \hat{p}_{\text {data }}}\left[\log \left(p_{\text {model }}(x)\right)\right] $

显然最小化KL散度就是在最小化分布之间的交叉摘。(KL散度 - 目标分布熵 = 交叉熵)

所以给定样本集,用最大似然和最小化KL散度优化模型参数是等价的(都要假设模型输出的分

布,用样本集计算或代替真实分布)。

3. MSE

针对Z=X-Y,f_{Z}(z)=\int_{-\infty}^{+\infty} f_X(x) f_Y(x-z) d x,如果两随机变量独立则有:

f_Z(z \mid Y=y)=f_X(z+x)

假设预测分布和真实分布相互独立,且给定真实值分布条件下,模型预测值与真实值之间的误差服从标准高斯分布,Z\mid{X}=(X-\hat{X})\mid{X}\sim N(0,1) ,(\mu=0, \sigma=1)

$ p_{Z\mid{X}}(z)=\frac{1}{\sqrt{2 \pi}} \exp \left(-\frac{\left(z\right)^2}{2}\right)=p_{Z\mid{X}}\left(x- \hat{x} \right)= \frac{1}{\sqrt{2 \pi}} \exp \left(-\frac{\left(\hat{x} - x\right)^2}{2}\right)$

p_{\hat{X}}(z+x_{i})=p_Z(z \mid X=x_{i})=\frac{1}{\sqrt{2 \pi}} \exp \left(-\frac{\left(z\right)^2}{2}\right)

\hat{x}=z+x_{i}p_{\hat{X}}(\hat{x})=\frac{1}{\sqrt{2 \pi}} \exp \left(-\frac{\left(\hat{x}-x_{i}\right)^2}{2}\right)

已知真实分布样本时,求得了预测模型应该的目标分布。

将预测出的样本视为参数估计中采集的真实样本,添加最大似然约束,“让最优预测模型接近预测模型”,即让预测模型达到理论上能做到的最优预测模型,从而预测真实分布。最大化似然函数,

\hat{\theta}=\arg \max _\theta \sum_{i=1}^m \log \left(p_{\hat{x}}\left(\hat{x}=\hat{x}^{(i)} ; \theta\right)\right) \\ =\arg \max _\theta \sum_{i=1}^m \log\frac{1}{\sqrt{2 \pi}} \exp \left(-\frac{\left(\hat{x}^{(i)}-x_{i}\right)^2}{2}\right) \\ =\arg \max _\theta -\sum_{i=1}^m\left(\hat{x}^{(i)}-x_{i}\right)^2

该式子的对比均方误差MSE=\frac{1}{m}\sum_{i=1}^m\left(\hat{x}^{(i)}-x_{i}\right)^2,两者显然等价,最小化均方差和最大化似然函数会得到相同的最优参数\hat{\theta} 。从概率的角度来说MSE是基于预测分布和真实分布相互独立,且给定真实值分布条件下,模型预测值与真实值之间的误差服从标准高斯分布假设的最大似然。

4. MASTERING ATARI WITH DISCRETE WORLD MODELS论文

4.1 Gaussian Likelihood

用高斯模型生成数据,就是假设模型输出的数据分布 p_{model}(x;\theta )为高斯分布,

应用最大似然估计,$ \hat{\theta}=\arg \max _\theta \mathbb{E}_{\mathrm{x} \sim \hat{\mathrm{p}}_{\text {data }}} \log \left(p_{\text {model }}(x ; \theta)\right) $,使模型输出的数据分布接近真实分布。

论文中Gaussian Likelihood定义的image log loss和如下,

class LossModel(nn.Module):
    def __init__(self, nx=1/64/64/3, nr=1, ng=1, nt=0.08, nq=0.1):
        super(LossModel, self).__init__()

        self.nx = nx
        self.nr = nr
        self.ng = ng
        self.nt = nt
        self.nq = nq

    def forward(self, x, r, gamma, z_logits, z_sample, x_hat, r_hat, gamma_hat, z_hat_logits):
        x_dist = torch.distributions.normal.Normal(
            loc=x_hat,
            scale=1.0
        )
        r_dist = torch.distributions.normal.Normal(
            loc=r_hat,
            scale=1.0
        )
        gamma_dist = torch.distributions.bernoulli.Bernoulli(
            logits=gamma_hat
        )
        z_hat_dist = torch.distributions.one_hot_categorical.OneHotCategorical(
            logits=z_hat_logits.reshape(-1, 32, 32)
        )
        z_dist = torch.distributions.one_hot_categorical.OneHotCategorical(
            logits=z_logits.reshape(-1, 32, 32).detach()
        )

        z_sample = z_sample.reshape(-1, 32, 32)

        loss = -self.nx*x_dist.log_prob(x).mean()\      # 计算x在正态分布中对应的概率的对数
                -self.nr*r_dist.log_prob(r).mean()\
                -self.ng*gamma_dist.log_prob(gamma.round()).mean()\
                -self.nt*z_hat_dist.log_prob(z_sample.detach()).mean()\
                +self.nq*z_dist.log_prob(z_sample).mean()

        return loss

根据公布的源码,image log loss计算步骤是:

        1. 图像预测器输出具有单位方差的对角高斯似然的均值,即p_{model}(x;\theta )的高维均值,方差为1,这样可以构成重构image数据的分布p_{model}(x;\theta_{estimate} )

        2. 将- \mathbb{E}_{\mathrm{x} \sim \hat{\mathrm{p}}_{\text {data }}} \log \left(p_{\text {model }}(x ; \theta)\right)作为损失函数,计算梯度更新image predictor,更新image predictor网络的参数,更新网络的输出(即image predictor高斯模型的均值参数),从而使高斯模型输出分布p_{model}(x;\theta )向着真实分布逼近。

4.2 类别分布隐变量的KL散度

为使 \hat{Z} 分布和 Z 分布接近,使用 KL 散度作为损失。在计算 $Z$ 和 $\hat{Z}$ 的 KL 散度时,首先需要将它们的概率分布转换为概率密度函数。在代码中,使用了 torch.distributions.one_hot_categorical.OneHotCategorical 来表示这些离散分布,然后通过 sample() 方法获得了样本,再利用 torch.softmax 转换为概率分布。

在计算 KL 散度时,需要对概率分布取对数。然后,对于每个类别的概率,计算其与 $z$ 和 $z_{\text{hat}}$ 分布的比值,并乘以 $z$ 分布的概率。最后对所有类别进行求和,得到 KL 散度的值。

具体来说,对于 \hat{Z} 和 Z ,它们的 KL 散度计算步骤如下:

  1. self.representation_model_mlp和self.transition_predictor网络分别输出 Z 和 \hat{Z} 的event log probabilities (unnormalized),z_logits和z_hat_logits。
  2. 将 Z 和 \hat{Z} 的 logits 利用torch.distributions.one_hot_categorical.OneHotCategorical 转换为概率分布,z_dist和z_hat_dist
  3. 计算 KL 散度时,$ D_{K L}\left(\hat{p}_{\text {data }} \| p_{\text {model }}\right)=\mathbb{E}_{x \sim \hat{p}_{\text {data }}}\left[\log \left(\hat{p}_{\text {data }}(x)\right)-\log \left(p_{\text {model }}(x)\right)\right] $,分成两部分似然进行计算,使用 Z 分布采样出样本z_sample,当成似然函数的样本代入分布计算。
  4. 为了保持采样z_sample时梯度的连续性,需要使用 straight-through estimator 的技巧,即
    z_sample = torch.distributions.one_hot_categorical.OneHotCategorical(
                logits=z_logits.reshape(-1, 32, 32)).sample()
    z_probs = torch.softmax(z_logits.reshape(-1, 32, 32), dim=-1)
    z_sample = z_sample + z_probs - z_probs.detach()
    通过将 \hat{Z} 的采样样本和概率相加,然后减去概率的 detach 值来实现。
  5. 计算 KL 散度。在代码中,KL 散度的计算实际上是计算了两个分布之间的交叉熵(cross entropy),再加上一些系数项,因为 KL 散度可以被定义为两个分布之间的交叉熵减去一个熵项。(代码简化了原文中的kl balancing,直接使用了 log likelihood代替了)
    KL_loss = -self.nt*z_hat_dist.log_prob(z_sample.detach()).mean()\
                    +self.nq*z_dist.log_prob(z_sample).mean()

4.3 一些思考

1. 为什么用高斯似然损失,不用KL散度,也不用MSE描述image和重构的image损失?

(1)KL散度与似然函数等价,因为world model将image predictor建模为高斯分布,没有网络将原始image样本解码高斯分布的参数,所以没法计算KL散度;

(2)将image predictor建模为高斯分布后,如果直接用image和重构image的MSE,则相当于没有用到这个假设。

2. 为什么类别隐变量之间用KL散度,不用类别隐变量之间直接做MSE?

因为直接用MSE的话,没有用到类别分布的假设。

3. 当模型输出建模为类别分布时,最大化类别似然函数是否等价于最小化交叉熵损失?

最大化似然等价于最小化KL散度,即$ \max _\theta \mathbb{E}_{\mathrm{x} \sim \hat{\mathrm{p}}_{\text {data }}} \log \left(p_{\text {model }}(x ; \theta)\right) $,网络的输出为类别分布的概率参数,以此构造类别分布,再结合真实样本标签计算似然函数

x_dist = torch.distributions.one_hot_categorical.OneHotCategorical(
            logits=x_logits.detach()
        )

loss = -x_dist.log_prob(x).mean()

只计算对应标签处的log probability,和cross entropy计算是等价的。

理论推导:

$ \mathbb{E}_{\mathrm{x} \sim \hat{\mathrm{p}}_{\text {data }}} \log \left(p_{\text {model }}(x ; \theta)\right)= \mathbb{E}_{\mathrm{y} \sim p(Y)} \log \left(p(\hat{Y})\right)=\frac{1}{m}\sum_{i=1}^{m}\log \left(p(\hat{Y}=y_{i})\right) $

y_{i}是one-hot编码时,

=\frac{1}{m}\sum_{i=1}^{m}\sum_{k=1}^{K}y_{i}^{k}\log \left(p(\hat{Y}=y_{i}^{k})\right),K为类别数

与交叉熵一致

Reference

1. https://zhuanlan.zhihu.com/p/384420701

2. https://zhuanlan.zhihu.com/p/266049554

  • 51
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值