#论文题目:HyperSCI:Learning Causal Effects on Hypergraphs(HyperSCI:基于超图的因果学习)
#论文地址:https://arxiv.org/abs/2207.04049
#论文源码开源地址:暂无
#论文所属会议:KDD 2022
#论文所属单位:弗吉尼亚大学夏洛茨维尔分校、微软
一、前序知识
-
因果推断:在常用的机器学习算法中,关注的是特征之间的相关性,而无法去识别特征之间的因果性,但很多时候在做决策与判断的时候,我们需要的是因果性,只有这样,才能说明是否缺少A导致B的变化。具体解释详见《因果推断——简介》、《大白话谈因果系列文章》
-
超图:是一种包含多种节点的集合,一个集合中的节点相互具有关系,同一个节点可以存在在多个集合中。(我本科毕业论文就是基于超图来研究的)
关于超图的形象解释引用论文中的图示如下(a):
-
一阶干涉和高阶干涉:
- 一阶干涉:如图所示,在 u u u1、 u u u2、 u u u3 都参与的聚会(gathering event)中即它们在同一条超边(hyperedge)内,个体 u u u1的感染结果(outcome)会受到其一阶个体 u u u2、 u u u3的影响。我们将 u u u2—> u u u1、 u u u3—> u u u1 , 这种影响称之为一阶干涉(first-order interference)。
- 高阶干涉:个体 u u u1的感染结果(outcome)还会受到其他个体之间相互作用的影响,即 u u u2和 u u u3间的相互作用也可能影响到病毒对 u u u1的暴露程度。我们将 u u u2 × u u u3—> u u u1这种二阶(second-order)互动效应的影响称之为高阶干涉(high-order interference)。
二、导读/创新点
- 为什么用超图结构而不是传统结构来构建网络:传统图(pairwise graph)的定义涵盖了大部分的应用场景(例如人与人之间的物理接触或社交网络),但它不能捕捉到群体互动(group interaction)的信息(即每个互动会涉及两个以上的人)。
- 为什么采用因果推断,而不是传统的基于统计学概率:因为现在的工作大多数是在统计相关性(statistical correlation)的角度进行研究,比如:通过捕捉一个人(individual)的人口信息(即节点特征 node features)、团体聚会史(超图结构 hypergraph structure)和感染结果(节点标签 node labels)之间的相关性来预测每个人(节点 node)的 COVID-19 感染风险,但是缺乏因果性推断,而因果性推断对于了解政策干预(如:戴口罩)对结果(如:感染 COVID-19)的影响尤为重要。形象化形容:个人是否戴口罩(实验 treatment)会如何在因果关系上影响其感染风险(结果 outcome)?
- 在计算超图中的高阶干涉的时候引入注意力机制对超图中的节点权重进行修正,以达到不同节点有不同权重的目的。
所以,这篇文章想要探索在超图上进行因果推断任务的学习。具体来说,这篇文章专注于从观察数据中估计在超图干涉(hypergraph interference)下的个体实验效果(individual treatment effect ITE)。
三、HyperSCI框架
简单来说,HyperSCI 控制了混杂因素(confounder),在表征学习(representation learning)的基础上建立了高阶干涉(high-order interference)模型,最后根据学习到的表征做出估计(estimation)。
可以看出,给定某个超图,算法分为三部分:混杂因子表征学习、高阶干涉模型表征学习、结果预测。
-
Confounder Representation Learning(混杂因子表征学习):对于初始化的节点向量 x x xi,我们先通过MLP将节点向量表征 x x xi编码到一个隐空间,即 z z zi = M L P MLP MLP( x x xi),我们得到了一组表征
· Z Z Z可以捕捉到所有潜在的混杂因素(confounder),所以模型可以通过控制 z z zi来减轻混杂偏倚(confounding bias)。我们将 Z Z Z称为混杂因素表征(confounder representation)。
· 并且由于混杂因素表征 的分布在对照组(control group)和实验组(treatment group)可能存在差异,在损失函数中加入差异惩罚项(discrepancy penalty)来平衡表征 Z Z Z。差异惩罚项可以用任何计算两个分布间距离指标来计算。本文采用 Wasserstein-1 distance 作为计算对照组和实验组间表征分布的差异惩罚项。 -
Interference Modeling(高阶干涉模型表征学习):
此部分分为两个阶段:- 超图卷积网络:我们计算
P
P
P(第L层的卷积表征因子),进而进行迭代训练,直到达到收敛时,我们得到
P
P
P(l+1)为最后的各节点的表征集合。
- 计算超图中各个节点所对应的超边的注意力(通过归一化方式):我们给每条超边 e e e计算其表征 z z ze。该表征是通过聚合其有关联的节点 N N Ne得到,即: z z ze = A G G AGG AGG({ z z zi | i i i属于 N N Ne}),其中 A G G AGG AGG可以是任何聚合函数。(详细公式请见原文)。
得到注意力之后,我们用其模拟不同程度的干涉,使用增强矩阵 H ⃗ \vec{H} H来代替上式中的原始关联矩阵 H H H。
这样一来,在同一超边内不同节点的干涉就可以被赋予不同的权重,来表示对建模干涉不同程度的贡献。我们将卷积层最后一层的表征定义为,并希望它能捕捉到每个节点的高阶干涉。
- 超图卷积网络:我们计算
P
P
P(第L层的卷积表征因子),进而进行迭代训练,直到达到收敛时,我们得到
P
P
P(l+1)为最后的各节点的表征集合。
-
Outcome Prediction(结果预测):在得到混杂因素表征 z z zi和干涉表征 p p pi后,我们对潜在结果(potential outcome)进行建模:
最终,每个实例i的ITE可以被估计为:
损失函数等其他具体公式详见原文
四、结论
文章提出了一个基于超图的因果学习模型,具体通过聚合混杂因子(利用MLP)和高阶干涉(利用注意力机制修正超边中的节点信息、利用多层图卷积网络进行迭代训练)的两部分节点向量表征进而得到 y y y^1i和 y y y^2i。