内容简介
ODPS Graph 是基于飞天平台实现的面向迭代的图处理框架,为用户提供了类似于 Pregel 的编程接口。用户需要将问题抽象成图的表述,然后通过一些超步进行以顶点为中心的迭代更新。
对于需要迭代学习模型参数的机器学习算法来说,图计算相比 MAP/REDUCE 具有天然的优势。
这篇文章将以用户的汽车品牌分布的推断为例,说明如何利用 ODPS-GRAPH 来做复杂的变分EM 推断。
问题陈述
在汽车类目下,通过用户购买的商品的属性信息,推断用户的汽车品牌的分布
假定有 K 个汽车品牌 b1, b2, ..., b_K,通过 “适合汽车” 这个属性id,可以简单统计出每个用户 u 在这 K 个汽车品牌的购买次数.
我们能观察到的就是这个购买次数,那么如何推断出用户的汽车品牌分布。
建模
采用 Bayesian Multinomial Mixture 模型来建模观察到的计数数据,模型图如下:
超参数 alpha 是一个 K 维向量。
隐变量 Z = [ z1, z2, ..., zN]
混合比例 pi
不可见数据 U = (Z, pi)
可见数据 V = [ v1, v2, ..., vN]
模型参数 theta = (alpha,phi)
全数据为 D = (V, Z, pi)
对 zn 采用 1-of-K 编码,即 zn 是一个 0/1 的 K 维向量,如果 vn 来自成分 k,则
znk = 1
znj = 0 for other j
下面开始 EM 算法的推导过程
观察数据的似然及其变分下界(VLB)分别为
在我们的模型中,不可见数据 (Z, pi) 的联合后验分布 intractable,这种情况下一般可以进行近似推断或者基于采样的MCMC方法。 在本文中,将使用一种近似推断方法 —— 变分推断 (Variational Inference)。
根据 mean-field theory,分别求解出 pi 的变分分布:
通过公式 (10) 和 (15) 不难看出,pi 和 Z 的后验分布之间是有联系的,互相之间通过统计量的期望值进行联系,因而实际求解过程中,需要进行迭代多轮直到两者的分布保持稳定。收敛后的分布即使两者的变分分布的最优解。
当推断出这些分布之后,将通过变分下界最大化来学习出模型参数 theta = (alpha, phi),在这个模型中,参数有封闭的求解公式,如下:
综上分析,我们给出模型的推断学习的算法
Input: 观察数据 V
Output:
a). 模型参数 alpha, phi
b) Z 的分布参数 gamma
Procedue:
1. 初始化参数 alpha, phi, gamma, sigma
2. 由 (15) 式计算 gamma
3. 由 (10), (11), (17) 迭代计算出 alpha, sigma 直到收敛
4. 由 (18) 式计算出 phi
5. 判定收敛条件是否达到,如果达到,则算法结束;否则进入 step 2.
由算法返回的 gamma 是一个 N * K 的矩阵, 矩阵的每一行对应用户 n 的后验汽车品牌分布。
ODPS-GRAPH实现
在本模型学习和推断中,涉及到迭代,而这个过程能够在 ODPS-GRAPH 上非常优雅的支持。
首先,是建立 Graph。
每个观察数据 n 是一个 Vertex,该节点维护三个信息:用户的 nick 作为 VertexId,一个 K 维向量 vn 以及一个需要推断的 K 维概率向量 gamman;
顶点之间没有直接的信息计算依赖,因此图中不需要边的存在。
然后,设计 AggregatorValue 及 Aggregator
事实上 Aggregator 能收集顶点信息,并进行一些用户定义的聚合操作。在我们的实现中,需要用 Aggregator 算出 (10) 和 (18) 式中的求和值。因而在 AggregatorValue 中维护两个数据 s4gramma 和 s4phi 分别用于计算这两个和。
另外在 AggregatorValue 中需要维护模型参数 alpha, phi 及参数 sigma。
除此之外,就是要实现 Vertex 的 compute() 方法,需要做的工作就是按照 (15) 式更新该顶点的 gamma 值。
最后是要做算法结束的判定,实现 Aggregator 的 terminate 方法,比较新旧参数的差异的 L2 范数,如果小于预先指定的容许误差 epsilon 或者超步数达到最大超步,则算法终止。
将实现的class打包即可运行,运行过程是这样的
add jar /home/weidong.yin/odps/lib/zvbmm.jar -f;
add jar /home/weidong.yin/packages/commons-math3-3.3/commons-math3-3.3.jar -f;
set odps.graph.worker.num=2;
jar -libjars zvbmm.jar,commons-math3-3.3.jar -classpath /home/weidong.yin/odps/lib/zvbmm.jar:/home/weidong.yin/packages/commons-math3-3.3/commons-math3-3.3.jar com.taobao.graph.test.VBMMAS zecheng_vbmm_in zecheng_vbmm_out $K;
这里的 commons-math3-3.3.jar 包中有 Gamma 函数 和 Digamma 函数可供调用。
一个玩具例子的运行情况如下:
输入数据: 13 个用户在 3 个汽车品牌上的购买计数
+------------+------------+
| key | info |
+------------+------------+
| u1 | 1,1,0 |
| u2 | 1,0,1 |
| u3 | 1,1,0 |
| u4 | 1,0,1 |
| u5 | 0,0,1 |
| u6 | 0,6,0 |
| u7 | 0,1,0 |
| u8 | 0,0,1 |
| u9 | 1,0,0 |
| u10 | 2,0,0 |
| u11 | 1,0,0 |
| u12 | 1,0,0 |
| u13 | 1,2,1 |
+------------+------------+
收敛过程记录如下:
superstep:1 -- superdelta:182605.64601456857
superstep:3 -- superdelta:325.1760773137442
superstep:5 -- superdelta:29.823993045142903
superstep:7 -- superdelta:8.858108454843975
superstep:9 -- superdelta:3.3334592359224535
superstep:11 -- superdelta:1.2654907417770958
superstep:13 -- superdelta:0.41016235969433495
superstep:15 -- superdelta:0.1008304226893457
superstep:17 -- superdelta:0.03374567141079521
superstep:19 -- superdelta:0.007017635448616098
superstep:21 -- superdelta:0.0013885342582021832
21步内达到容许误差,收敛是比较快的。
推断结果 gamma:
+------------+------------+
| key | info |
+------------+------------+
| u13 | 0.060984323394069596,0.9240107583094663,0.01500491829646417 |
| u11 | 0.9226879885707154,0.034669555298412306,0.04264245613087234 |
| u2 | 0.5773565581849824,0.015464121746682774,0.4071793200683348 |
| u4 | 0.5773565581849824,0.015464121746682774,0.4071793200683348 |
| u8 | 0.12802749598577526,0.05813812103740773,0.813834382976817 |
| u6 | 8.671555846303735E-9,0.9999999911750542,1.533899796361806E-10 |
| u9 | 0.9226879885707154,0.034669555298412306,0.04264245613087234 |
| u10 | 0.9927099076999223,0.0022000835657728438,0.0050900087343050075 |
| u1 | 0.5205681717937852,0.4652215734656923,0.014210254740522454 |
| u12 | 0.9226879885707154,0.034669555298412306,0.04264245613087234 |
| u3 | 0.5205681717937852,0.4652215734656923,0.014210254740522454 |
| u5 | 0.12802749598577526,0.05813812103740773,0.813834382976817 |
| u7 | 0.060984323394069596,0.9240107583094662,0.01500491829646417 |
+------------+------------+
推断结果 sigma
(-0.7209447347994979,-1.1718416342165527,-1.5968869532868073)
模型参数 alpha
(345.0677795410156,220.01290893554688,144.00588989257812)
模型参数 phi
phi[0]:(0.7224296368389298,0.12706511156093872,0.15050525160013148)
phi[1]:(0.15566385389422016,0.760368639366158,0.08396750673962187)
phi[2]:(0.28444353588382226,0.021037927314571835,0.694518536801606)
对于大的数据的训练和推断,仿此过程即可。
注记
- 似然函数的下界 VLB 是非凸的函数,因此在做 EM 推断或者变分 EM 推断时存在局部极小问题,选择好的参数初始值对于得到合理的结果非常重要,需要细致的选择初始值;
- Muitinomial Mixture 模型可以对特征为计数的观察数据进行建模;
- ODPS-GRAPH 上可以实现任何 EM 相关算法的推断,其过程可仿效这篇文章的实现。