【2019/ICML】DAG-GNN: DAG Structure Learning with Graph Neural Networks


原文链接:https://dreamhomes.github.io/posts/202101041501.html


文章链接:https://arxiv.org/abs/1904.10098
源码链接:https://github.com/fishmoon1234/DAG-GNN

TL;DR

论文中提出一种新的DAG编码架构 DAG-GNN,其实模型的本质就是一个图变分自编码器,模型的优点是既能处理连续型变量又能处理离散型变量;在人工数据集和真实数据集中验证了模型结果可以达到全局最优 🤔;

Model / Algorithm

论文中的整体模型架构如下:

Linear Structural Equation Model

论文中首先通过生成模型来泛化线性结构等价模型;假设 A ∈ R m × m A \in \mathbb{R}^{m \times m} ARm×m 表示DAG的加权邻接矩阵, X ∈ R m × d X \in \mathbb{R}^{m \times d} XRm×d 表示每个节点的特征,那么线性模型的的编码方式为:
X = A T X + Z ( 1 ) X=A^{T} X+Z \quad\quad\quad(1) X=ATX+Z(1)
其中 Z ∈ R m × d Z \in \mathbb{R}^{m \times d} ZRm×d 表示噪声矩阵;如果图中节点是以拓扑序排列的,那么矩阵 A A A 是一个严格的上三角矩阵,因此DAG中的 ancestral sampling 等价于三角等式的解:
X = ( I − A T ) − 1 Z ( 2 ) X=\left(I-A^{T}\right)^{-1} Z \quad\quad\quad(2) X=(IAT)1Z(2)

Proposed Graph Neural Network Model

上述等式 (2) 可以写为 X = f A ( Z ) X=f_A(Z) X=fA(Z),可以表示为数据节点特征 Z Z Z 并得到embedding X X X。传统的GCN 架构计算公式如下:
X = A ^ ⋅ ReLU ⁡ ( A ^ Z W 1 ) ⋅ W 2 X=\widehat{A} \cdot \operatorname{ReLU}\left(\widehat{A} Z W^{1}\right) \cdot W^{2} X=A ReLU(A ZW1)W2
由于公式 (2) 的特殊结构,因此提出新的图神经网络架构,注意这是解码器的结构
X = f 2 ( ( I − A T ) − 1 f 1 ( Z ) ) ( 3 ) X=f_{2}\left(\left(I-A^{T}\right)^{-1} f_{1}(Z)\right)\quad\quad\quad(3) X=f2((IAT)1f1(Z))(3)

其中 f 1 , f 2 f_1, f_2 f1,f2 表示 Z , X Z, X Z,X 的非线性的转换函数;

Model Learning with Variational Autoencoder

对于给定的分布 Z Z Z 和样本 X 1 , ⋯   , X n X^1, \cdots, X^n X1,,Xn,生成模型的目标是最大化对数函数:
1 n ∑ k = 1 n log ⁡ p ( X k ) = 1 n ∑ k = 1 n log ⁡ ∫ p ( X k ∣ Z ) p ( Z ) d Z \frac{1}{n} \sum_{k=1}^{n} \log p\left(X^{k}\right)=\frac{1}{n} \sum_{k=1}^{n} \log \int p\left(X^{k} \mid Z\right) p(Z) d Z n1k=1nlogp(Xk)=n1k=1nlogp(XkZ)p(Z)dZ
由于上式难以解决因此使用变分贝叶斯;

使用变分后验概率 q ( Z ∣ X ) q(Z|X) q(ZX) 来近似实际后验概率 q ( Z ∣ X ) q(Z|X) q(ZX)。网络优化的结果是 ELBO(the evidence lower bound)
L E L B O = 1 n ∑ k = 1 n L E L B O k L_{\mathrm{ELBO}}=\frac{1}{n} \sum_{k=1}^{n} L_{\mathrm{ELBO}}^{k} LELBO=n1k=1nLELBOk
其中
L E L B O k ≡ − D K L ( q ( Z ∣ X k ) ∥ p ( Z ) ) + E q ( Z ∣ X k ) [ log ⁡ p ( X k ∣ Z ) ] \begin{array}{r} L_{\mathrm{ELBO}}^{k} \equiv-D_{\mathrm{KL}}\left(q\left(Z \mid X^{k}\right) \| p(Z)\right) \\ \quad+\mathrm{E}_{q\left(Z \mid X^{k}\right)}\left[\log p\left(X^{k} \mid Z\right)\right] \end{array} LELBOkDKL(q(ZXk)p(Z))+Eq(ZXk)[logp(XkZ)]

基于 (3)式的解码器结构,对应的编码器结构为
Z = f 4 ( ( I − A T ) f 3 ( X ) ) ( 5 ) Z=f_{4}\left(\left(I-A^{T}\right) f_{3}(X)\right) \quad\quad\quad(5) Z=f4((IAT)f3(X))(5)
其中 f 4 , f 3 f_4, f_3 f4,f3 表示 f 2 , f 1 f_2,f_1 f2,f1 的逆函数。

Loss Function

对于编码器,使用MLP表示 f 3 f_3 f3和恒等映射表示 f 4 f_4 f4,变分后验概率 q ( Z ∣ X ) q(Z|X) q(ZX) 是一个因子高斯分布均值 M Z ∈ R m × d M_Z\in \mathbb{R}^{m\times d} MZRm×d 标准差 S Z ∈ R m × d S_Z\in \mathbb{R}^{m\times d} SZRm×d,可以通过编码器来进行计算:
[ M Z ∣ log ⁡ S Z ] = ( I − A T ) MLP ⁡ ( X , W 1 , W 2 ) ( 6 ) \left[M_{Z} \mid \log S_{Z}\right]=\left(I-A^{T}\right) \operatorname{MLP}\left(X, W^{1}, W^{2}\right)\quad\quad\quad(6) [MZlogSZ]=(IAT)MLP(X,W1,W2)(6)
其中 MLP ⁡ ( X , W 1 , W 2 ) : = ReLU ⁡ ( X W 1 ) W 2 \operatorname{MLP}\left(X, W^{1}, W^{2}\right):=\operatorname{ReLU}\left(X W^{1}\right) W^{2} MLP(X,W1,W2):=ReLU(XW1)W2

对于生成模型,使用恒等映射表示 f 1 f_1 f1 MLP来表示 f 2 f_2 f2,得到的似然 p ( X ∣ Z ) p(X | Z) p(XZ) 符合高斯分布均值为 M X ∈ R m × d M_X\in \mathbb{R}^{m\times d} MXRm×d 标准差为 S X ∈ R m × d S_X\in \mathbb{R}^{m\times d} SXRm×d,解码器的计算公式如下:
[ M X ∣ log ⁡ S X ] = MLP ⁡ ( ( I − A T ) − 1 Z , W 3 , W 4 ) ( 7 ) \left[M_{X} \mid \log S_{X}\right]=\operatorname{MLP}\left(\left(I-A^{T}\right)^{-1} Z, W^{3}, W^{4}\right)\quad\quad\quad(7) [MXlogSX]=MLP((IAT)1Z,W3,W4)(7)

基于公式(6)(7),式(4)中的KL散度项为:
D K L ( q ( Z ∣ X ) ∥ p ( Z ) ) = 1 2 ∑ i = 1 m ∑ j = 1 d ( S Z ) i j 2 + ( M Z ) i j 2 − 2 log ⁡ ( S Z ) i j − 1 \begin{array}{l} D_{\mathrm{KL}}(q(Z \mid X) \| p(Z))= \\\\ \quad \frac{1}{2} \sum_{i=1}^{m} \sum_{j=1}^{d}\left(S_{Z}\right)_{i j}^{2}+\left(M_{Z}\right)_{i j}^{2}-2 \log \left(S_{Z}\right)_{i j}-1 \end{array} DKL(q(ZX)p(Z))=21i=1mj=1d(SZ)ij2+(MZ)ij22log(SZ)ij1
重构准确率项为:
E q ( Z ∣ X ) [ log ⁡ p ( X ∣ Z ) ] ≈ 1 L ∑ l = 1 L ∑ i = 1 m ∑ j = 1 d − ( X i j − ( M X ( l ) ) i j ) 2 2 ( S X ( l ) ) i j 2 − log ⁡ ( S X ( l ) ) i j − c \begin{array}{c} \mathrm{E}_{q(Z \mid X)}[\log p(X \mid Z)] \approx \\\\ \frac{1}{L} \sum_{l=1}^{L} \sum_{i=1}^{m} \sum_{j=1}^{d}-\frac{\left(X_{i j}-\left(M_{X}^{(l)}\right)_{i j}\right)^{2}}{2\left(S_{X}^{(l)}\right)_{i j}^{2}}-\log \left(S_{X}^{(l)}\right)_{i j}-c \end{array} Eq(ZX)[logp(XZ)]L1l=1Li=1mj=1d2(SX(l))ij2(Xij(MX(l))ij)2log(SX(l))ijc
对于不同类型变量的处理论文中使用了不同的结构,详细参考原文推导过程。

Experiments

人工数据集

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值