JSA是一种用于无监督隐变量学习的算法。以下是我个人对相关算法的理解。
背景
假设我们有一批数据,其中数据样本用表示,每个数据样本都对应一个隐变量
,那么我们的建模目标就是两者的联合概率
。一般情况下我们不会直接对上述联合概率进行参数化,而是会将其拆解为:
。其中条件分布
表示的是生成模型,能从隐变量生成数据样本等,先验分布
一般可由我们自己假设,比如我们可以假设隐变量符合一个高斯分布。
然而,很多时候每个数据样本的隐变量是未知的且没有标注的,因此我们需要找到一种无监督学习方法来训练上述的生成模型
。JSA就是这样一种方法,但在介绍JSA之前,先介绍一下更广为人知的变分学习(Variational Learning, VL)方法。
变分学习
如果我们能从数据样本中推出对应的隐变量
,那么我们就能对生成模型
进行训练。变分学习的核心就是用一个神经网络推断模型
来近似真实的后验分布
。
对于任何一对,由于我们只能观测到
,因此优化目标就是最大化
,进一步,其优化目标推导如下(读者可自行跳到后续加粗字段):
左右两边同时在的分布下对
做积分,由于左侧
与
无关,因此不变,上式进一步转化为:
注意KL散度是大于等于0的,所以上述的号才成立。之所以要舍弃上式中的KL项,是因为
是未知的(
只对
进行参数化)。于是我们的优化目标从原来的最大化
转变为了最大化其下界,用ELBO表示。我们进一步观察ELBO,如下:
最后一个式子即为一般情况下变分学习的优化目标,推导到此结束。该推导结果其实是非常直观的,我们总结一下变分学习优化目标中的两项分别在干什么,是怎样实现隐变量的无监督学习的
- 第一项最大化在
分布下对数概率
的期望。根据蒙特卡洛方法,我们可以直接先从
中采出隐变量
,然后再将
带入到
中并最大化其对数概率即可。这就实现了在隐变量
未知的情况下优化生成模型的目的。
- 第二项是最小化
同先验分布的KL距离。这一项本质是对
的约束,不然我们无法保证第一项中从
中采出的隐变量是“可靠”的。
联合随机近似(JSA)
随机近似(SA)
联合随机近似的本质其实是SAEM+MCMC。SAEM指的是用随机近似(Stochastic Approximation, SA)算法来解决最大似然估计,MCMC指的是马尔可夫蒙特卡洛算法。MCMC方法介绍起来比较复杂,读者可自行了解,下面仅介绍SA算法。
SA算法主要解决的是特殊形式的求根问题,即找到满足下式的:
下面的具体算法读者也可以跳过,只需要知道任何满足上式的方程都能用SA方法来求解即可。
SA算法主要可分为两步,第一步是利用蒙特卡洛方法从采样,第二步是将采出的样本代入
,利用计算得到的函数值更新
,更具体的算法可参考下图。
SA框架下的优化推导
与变分学习相同,对于隐变量模型,我们只能以边缘分布
的对数为优化目标。为了使用SA算法,我们可以将原先求最值的优化目标转化为导数等于0的求根问题,即:
进一步,我们需要把上面左式子转化为期望形式,如下:
这个式子看着比较绕,但并不复杂。这里用到了以及Fisher Identity
。
于是我们的优化目标就转化为了标准的可以用SA算法求解的求根问题:
这里可以根据实际应用需求选择是否拆解为生成模型分布和先验分布的乘积。
根据SA算法,我们需要从分布中采样,但在前面已经说过
是未知的,所以问题的关键就变成了如何从
中采样。
JSA选择采用MCMC方法对进行采样。MCMC方法有很多子方法,JSA采用其中的MIS方法(Metropolis independence sampler)。MIS方法需要一个提议分布(proposal),我们同样用
表示,模型从提议分布中采出隐变量后,再以一定的概率选择接收该样本,如此迭代一定步数后得到一个马尔科夫链,只要这个链足够长,我们就认为链上的样本是来自于目标分布的样本。具体而言,对于第
步MIS,可分为两个步骤:
- 从提议分布中采样
- 以概率
接收
,否则
同样的,上述接收概率里的是无法计算的,但我们可以将分子分母同时乘以
,从而将接收概率变为:
,可以直接计算得到。
由此,我们将上式求根方程以及采样方法代入到SA算法中即可实现对参数的更新。
上面介绍了如何在SA框架下解决无监督隐变量学习的问题。但一般在该框架下,提议分布是固定的,也就是其要没没有神经网络参数,要么参数不会被更新。JSA的核心idea是在训练过程中联合优化上述的提议分布
。优化方式则是用MIS采样得到的隐变量
和对应的
去最大化
。
总结一下,上述SA算法和变分学习有很多相似之处。他们的本质不同是,变分学习使用的是中的样本来优化
,而SA采用的是
中的样本来优化;变分学习需要对
加一个约束项,而SA则需要使用MCMC方法才能从未知的
采样。JSA则是在SA算法上的改进,让其中的提议分布也能被更新。
JSA算法的工程实现
利用MCMC方法采样,往往需要成千上万的步数才能采到足够接近目标分布的样本,尤其是对于一些高维分布。理论上对于每一个样本,我们都需要这么多的采样步数才能得到符合条件的隐变量
,这种做法在现在看来肯定是不现实的。JSA相关工作中提出了一种trick来解决这种问题。具体而言,在每次遍历到样本
的时候,MCMC算法只向前走一步,然后把得到的
保存下来,在下一次再遍历到这个样本时,在已保存的
的基础上再向前进行一步MCMC算法。但笔者思考后觉得该做法存在如下三个问题:
- MCMC步数仍不够多。该方法相当于每一个epoch对所有样本执行一步MCMC算法,那么训练过程中有多少epoch就会执行多少MCMC步,这个步数还是不够大;
- 数据的存储和读取。该方法中每一个样本
所对应的
需要保存下来,并不断更改。在数据集较大不能全部加载到内存中时,隐变量
的存储和读取的实现会比较困难;
- 采样的偏差可能很大。该方法使用上一个epoch中得到的隐变量作为上一个时刻的隐变量,并根据MCMC算法不断更新隐变量。但需要注意的是在这个过程中MCMC的目标分布是在不断变化的,因此这种采样方法是否等同于标准的MCMC采样还存疑。