Heterogeneous Graph Transformer
论文链接:https://arxiv.org/pdf/2003.01332
代码链接:
- https://arxiv.org/pdf/1909.01315
- https://github.com/acbull/pyHGT
关键词:图神经网络、HGT
相关概念
Heterogeneous Graph 异构图
Definition : A heterogeneous graph is defined as a directed graph G = ( V , E , A , R ) G = (V, E, A, R) G=(V,E,A,R) where each node v ∈ V v \in V v∈V and each edge e ∈ E e \in E e∈E are associated with their type mapping functions τ ( v ) : V → A \tau(v) : V \rightarrow A τ(v):V→A and Φ ( e ) : E → R \Phi(e) : E \rightarrow R Φ(e):E→R, respectively
Heterogeneous Graph是一种图数据结构,用于表示现实世界中的复杂系统,其中包含多种类型的节点(entities)和多种类型的边(relationships)。与传统的同构图(Homogeneous Graph)不同,同构图中的所有节点和边都是相同类型的,异构图能够更准确地抽象和模拟具有不同实体和多样关系的系统。
元路径:在异构图中,可以通过定义元路径(meta path)来指导信息的流动,这是模拟不同类型节点之间关系的一种方式。
有哪些“异构”
- 多种节点类型:图中的节点可以代表不同的实体,例如在学术网络中,节点可以代表论文、作者、机构、会议或研究领域等。
- 多种边类型:边代表实体之间的关系,这些关系可以是多样的,例如“作者-论文”的合作关系、“论文-会议”的发表关系、“作者-机构”的隶属关系等。
- 多样的关系属性:每种类型的边可能具有不同的属性,例如合作的作者可能有不同的贡献顺序,如第一作者、第二作者等。
Graph Neural Networks
Definition : General GNN Framework:
Suppose
H
l
[
t
]
H^l[t]
Hl[t] is the node representation of node t at the (l)-th GNN layer, the update procedure from the (l-1)-th layer to the (l)-th layer is:
H
l
[
t
]
←
Aggregate
∀
s
∈
N
(
t
)
,
∀
e
∈
E
(
s
,
t
)
(
Extract
(
H
(
l
−
1
)
[
s
]
;
H
(
l
−
1
)
[
t
]
,
e
)
)
H^l[t] \leftarrow \underset{\forall s\in N(t),\ \forall e\in E(s,t)}{\text{Aggregate}} \left(\text{Extract}(H^{(l-1)}[s]; H^{(l-1)}[t], e)\right)
Hl[t]←∀s∈N(t), ∀e∈E(s,t)Aggregate(Extract(H(l−1)[s];H(l−1)[t],e))
where
N
(
t
)
N(t)
N(t) denotes all the source nodes of node t and
E
(
s
,
t
)
E(s,t)
E(s,t) denotes all the edges from node s to t.
图神经网络(GNN)中的两个最重要的操作:Extract()
和 Aggregate()
。
- Extract() 操作:
Extract()
函数是邻居信息提取器。- 它从源节点的表示 H l − 1 [ s ] H^{l−1}[s] Hl−1[s] 中提取有用信息,同时使用目标节点的表示 H l − 1 [ t ] H^{l−1}[t] Hl−1[t]和连接两个节点的边 $ e$ 作为查询。
- 这个操作的目的是捕捉节点的局部邻域信息,通常涉及到节点特征和边特征的交互。
- Aggregate() 操作:
Aggregate()
函数收集源节点的邻域信息。- 它通过聚合操作符(如均值、求和、最大值)来实现,这些操作符可以对邻居节点的信息进行汇总。
- 除了基本的聚合方法,还可以设计更复杂的池化和归一化函数来进一步处理和提炼节点的邻域信息。
Extract()
和 Aggregate()
操作是构建GNN模型时的核心组成部分,它们共同作用于节点间的信息交换和特征聚合过程。
Extract()
负责从邻居节点那里收集信息,而 Aggregate()
则负责将这些收集来的信息进行整合,以便在图神经网络的下一层中使用。这个过程允许网络学习图中节点的高级表示,这些表示可以用于各种下游任务,如节点分类、链接预测等。
研究目的
现有的GNN方法的不足:
以HAN, GTNs, HetGNN等方法为例。
(1)大多数方法需要为异质图设计元路径
(2)要不就是假定不同类别的节点/边共享相同的特征和表示空间,要不就是单独为某一类型的节点和边设计不同的不可共享的参数。这样的话不能充分捕获异质图的属性信息;
(3)大多数方法都没有考虑异质图的动态特征;
(4)不能建模Web-scale的异质图。
进行异质图神经网络的研究,目的是:
- 保留节点和边类型的有依赖关系的特征
- 捕获网络的动态信息
- 避免自定义元路径
- 并且可扩展到大规模(Web-scale)的图上
Heterogeneous Graph Transformer
给定采样的异质子图,HGT抽取出了所有有连边的点对,输入进HGT层
HGT层的目的是从源节点聚合信息,获得目标节点的上下文表示。这属于node embedding部分,为下游任务做准备。
HGT可被分解成3个部分
-
异质互注意力(Heterogeneous Mutual Attention)
-
异质消息传递(Heterogeneous Message Passing )
-
针对特定任务的聚合(Target-Specific Aggregation)。
通过堆叠L层HGT,得到整个图的节点表示 H ( L ) H^{(L)} H(L) ,然后用于端到端的训练或者输入给下游任务。
整个框架中,高度依赖于元关系—— < τ ( s ) , ϕ ( e ) , τ ( t ) > <\tau(s), \phi(e), \tau(t)> <τ(s),ϕ(e),τ(t)>来参数化权重矩阵。和现有的维每个元路径维护一个矩阵的方法相比,HGT的三元组参数可以更好地利用异质图的schema来实现参数共享。
一方面,这样的参数共享有助于利用出现频次较少的类型的边,从而实现快速的自适应和泛化。
另一方面,使用较少的参数,仍然实现了保留不同类型边的特征。
Heterogeneous Mutual Attention
此部分的作用是:计算两个相连点之间的注意力(重要性)。
The general attention-based GNNs as follows:
H
l
[
t
]
←
Aggregate
∀
s
∈
N
(
t
)
,
∀
e
∈
E
(
s
,
t
)
(
Attention
(
s
,
t
)
⋅
Message
(
s
)
)
H^l[t] \underset{\forall s\in N(t),\ \forall e\in E(s,t)}{\leftarrow \text{Aggregate}}\left(\text{Attention}(s, t) \cdot \text{Message}(s)\right)
Hl[t]∀s∈N(t), ∀e∈E(s,t)←Aggregate(Attention(s,t)⋅Message(s))
三个基础算子:
- Attention : estimate the importance of each source node
- Message : extract the message by using only the source nodes
- Aggregate : aggregate the neighborhood message by the attention weight.
Graph Attention Network(GAT) 所使用的三个算子公式如下:
Attention
GAT
(
s
,
t
)
=
Softmax
s
∈
N
(
t
)
(
a
⃗
(
W
H
(
l
−
1
)
[
t
]
∥
W
H
(
l
−
1
)
[
s
]
)
)
Message
GAT
(
s
)
=
W
H
(
l
−
1
)
[
s
]
Aggregate
GAT
(
⋅
)
=
σ
(
Mean
(
⋅
)
)
\text{Attention}_{\text{GAT}}(s, t) = \underset{s \in N(t)}{\text{Softmax}}\left(\vec{a}(W H^{(l-1)}[t] \| W H^{(l-1)}[s])\right) \\ \text{Message}_{\text{GAT}}(s) = W H^{(l-1)}[s] \\ \text{Aggregate}_{\text{GAT}}(\cdot) = \sigma\left(\text{Mean}(\cdot)\right)
AttentionGAT(s,t)=s∈N(t)Softmax(a(WH(l−1)[t]∥WH(l−1)[s]))MessageGAT(s)=WH(l−1)[s]AggregateGAT(⋅)=σ(Mean(⋅))
- 注意力机制公式: 公式定义了计算两个节点 s 和 t之间注意力分数的方法。其中:
- a ⃗ \vec{a} a是一个可学习的权重向量,用于计算节点 t和其邻居 s之间的注意力分数
- W是权重矩阵,用于将节点的特征从一层传递到下一层
- 消息传递公式:公式定义了如何从邻居节点 s 传递消息。其中:
- W是权重矩阵,用于将邻居节点 s的特征转换为新的特征表示。
- 聚合操作公式:定义如何聚合来自邻居节点的所有消息以更新目标节点 t的表示。其中:
- Mean表示对所有邻居节点的消息取平均值。
- σ是一个非线性激活函数,用于引入非线性特性。
GAT的缺陷:GAT assumes that s and t have the same feature distributions by using one weight matrix W W W. Such an assumption is usually incorrect for heterogeneous graphs, where each type of nodes can have its own feature distribution.
文章改变了注意力计算的方式,公式如下:
Attention
HGT
(
s
,
e
,
t
)
=
Softmax
∀
s
∈
N
(
t
)
(
∣
∣
i
∈
[
1
,
h
]
ATT-head
i
(
s
,
e
,
t
)
)
ATT-head
i
(
s
,
e
,
t
)
=
(
K
i
(
s
)
W
ϕ
(
e
)
ATT
Q
i
(
t
)
T
)
μ
⟨
τ
(
s
)
,
ϕ
(
e
)
,
τ
(
t
)
⟩
d
K
i
(
s
)
=
K-Linear
τ
(
s
)
i
(
H
(
l
−
1
)
[
s
]
)
Q
i
(
t
)
=
Q-Linear
τ
(
t
)
i
(
H
(
l
−
1
)
[
t
]
)
\text{Attention}_{\text{HGT}}(s, e, t) = \underset{\forall s\in N(t)}{\text{Softmax}}\left(\underset{i\in [1,h]}{||} \text{ATT-head}^i(s, e, t)\right) \\ \text{ATT-head}^i(s, e, t) = \left(K^i(s)W^{\text{ATT}}_{\phi(e)} Q^i(t)^T\right)\frac{\mu \langle \tau(s), \phi(e), \tau(t) \rangle}{\sqrt{d}} \\ K^i(s) = \text{K-Linear}^i_{\tau(s)}\left(H^{(l-1)}[s]\right) \\ Q^i(t) = \text{Q-Linear}^i_{\tau(t)}\left(H^{(l-1)}[t]\right)
AttentionHGT(s,e,t)=∀s∈N(t)Softmax(i∈[1,h]∣∣ATT-headi(s,e,t))ATT-headi(s,e,t)=(Ki(s)Wϕ(e)ATTQi(t)T)dμ⟨τ(s),ϕ(e),τ(t)⟩Ki(s)=K-Linearτ(s)i(H(l−1)[s])Qi(t)=Q-Linearτ(t)i(H(l−1)[t])
- 注意力机制公式: 计算目标节点 t与其邻居节点 s之间的注意力分数。其中:
- h 是注意力头(attention heads)的数
||
表示将所有注意力头的分数拼接起来- 注意力头公式: 定义了单个注意力头是如何计算的。其中:
- W ϕ ( e ) ATT W^{\text{ATT}}_{\phi(e)} Wϕ(e)ATT是与边 e的类型相 ϕ ( e ) \phi(e) ϕ(e)关的注意力权重矩阵。
- μ是一个用于调整注意力分数的标量(先验知识)
- d是向量的维度,用于缩放点积的计算结果
- 键向量和查询向量投影公式:定义了如何将节点 s和 t 的隐藏表示通过线性变换投影到对应的键向量和查询向量
K-Linear
函数:
K-Linear
(Key Linear)函数用于将源节点 s的特征 H l − 1 [ s ] H^{l−1}[s] Hl−1[s]投影到所谓的“键”(Key)空间。- 在注意力机制中,键向量用于与查询向量进行比较,以计算注意力分数。
K-Linear
函数通常依赖于节点的类型 τ s \tau{s} τs,这意味着不同类型的节点将使用不同的线性变换矩阵来投影其特征,从而捕捉不同类型的特定特征。Q-Linear
函数:
Q-Linear
(Query Linear)函数用于将目标节点 t的特征 H l − 1 [ t ] H^{l−1}[t] Hl−1[t]投影到“查询”(Query)空间。- 查询向量是注意力机制中的另一个关键组成部分,它与键向量相结合,确定目标节点应该如何聚合来自其邻居节点的信息。
- 类似于
K-Linear
,Q-Linear
函数也依赖于节点的类型 τ ( t ) \tau(t) τ(t),确保不同类型的目标节点可以有不同的查询表示。
Heterogeneous Message Passing
Message H G T ( s , e , t ) = ∣ ∣ i ∈ [ 1 , h ] MSG-head i ( s , e , t ) MSG-head i ( s , e , t ) = M-Linear τ ( s ) i ( H ( l − 1 ) [ s ] ) W ϕ ( e ) MSG \text{Message}_{HGT}(s, e, t) = \underset{i\in [1,h]}{||}\text{MSG-head}^i(s, e, t) \\ \text{MSG-head}^i(s, e, t) = \text{M-Linear}^i_{\tau(s)}\left(H^{(l-1)}[s]\right)W^{\text{MSG}}_{\phi(e)} MessageHGT(s,e,t)=i∈[1,h]∣∣MSG-headi(s,e,t)MSG-headi(s,e,t)=M-Linearτ(s)i(H(l−1)[s])Wϕ(e)MSG
每个节点的信息通过一个线性层变成 d h \frac dh hd维度,右乘与边的类型相关的矩阵得到单个MSG-head,通过h个head拼接得到完整的信息。
Target-Specific Aggregation
H ˋ ( l ) [ t ] = ∑ s ∈ N ( t ) ( Attention HGT ( s , e , t ) ⋅ Message HGT ( s , e , t ) ) H ( l ) [ t ] = A-Linear τ ( t ) ( σ H ˋ ( l ) [ t ] ) + H ( l − 1 ) [ t ] \grave H^{(l)}[t] = \sum_{s \in N(t)} \left(\text{Attention}_{\text{HGT}}(s, e, t) \cdot \text{Message}_{\text{HGT}}(s, e, t)\right) \\ H^{(l)}[t] = \text{A-Linear}_{\tau(t)} \left( \sigma \grave H^{(l)}[t]\right) + H^{(l-1)}[t] Hˋ(l)[t]=s∈N(t)∑(AttentionHGT(s,e,t)⋅MessageHGT(s,e,t))H(l)[t]=A-Linearτ(t)(σHˋ(l)[t])+H(l−1)[t]
残差连接:
在HGT中,残差连接通常用于将节点在前一层的表示 H ( l − 1 ) [ t ] H^{(l−1)}[t] H(l−1)[t]直接添加到通过消息传递和注意力机制计算得到的更新表示上 H ( l ) [ t ] H^{(l)}[t] H(l)[t],以此来增强模型的表达能力。这种结构使得模型能够同时考虑局部邻域信息和跨层的特征信息,从而提高对图结构的捕捉能力
Relative Temporal Encoding
通过相对时序编码处理动态异构图
给定源节点s和目标节点t,以及它们对应的时间戳T(s) , T(t)。定义相对时间间隔为 Δ T ( t , s ) = T ( t ) − T ( s ) \Delta T(t,s)=T(t)−T(s) ΔT(t,s)=T(t)−T(s),作为得到相对时间编码 R T E ( Δ T ( t , s ) ) RTE(\Delta T(t,s)) RTE(ΔT(t,s))的索引。
注意训练集不会覆盖到所有可能的时间间隔,因此RTE要具有泛化到不可见的时间和时间间隔的能力。作者采取了一组固定的正弦函数作为偏置,并使用了可微调的线性映射构成RTE:
Base ( Δ T ( t , s ) , 2 i ) = sin ( Δ T t , s 1000 0 2 i d ) Base ( Δ T ( t , s ) , 2 i + 1 ) = cos ( Δ T t , s 1000 0 2 i + 1 d ) RTE ( Δ T ( t , s ) ) = T-Linear ( Base ( Δ T t , s ) ) \text{Base}(\Delta T(t, s), 2i) = \sin\left(\frac{\Delta T_{t,s}}{10000^\frac{2i}{d}}\right) \\ \text{Base}(\Delta T(t, s), 2i + 1) = \cos\left(\frac{\Delta T_{t,s}}{10000^\frac{2i+1}{d}}\right) \\ \text{RTE}(\Delta T(t, s)) = \text{T-Linear} \left(\text{Base}(\Delta T_{t,s})\right) Base(ΔT(t,s),2i)=sin(10000d2iΔTt,s)Base(ΔT(t,s),2i+1)=cos(10000d2i+1ΔTt,s)RTE(ΔT(t,s))=T-Linear(Base(ΔTt,s))
最后,将相对于目标节点t tt的时间编码加入到源节点s ss的表示中:
H ^ ( l − 1 ) [ s ] = H ( l − 1 ) [ s ] + RTE ( Δ T ( t , s ) ) \hat H^{(l-1)}[s] = H^{(l-1)}[s] + \text{RTE}(\Delta T(t, s)) H^(l−1)[s]=H(l−1)[s]+RTE(ΔT(t,s))
H ^ ( l − 1 ) [ s ] = H ( l − 1 ) [ s ] + RTE ( Δ T ( t , s ) ) \hat H^{(l-1)}[s] = H^{(l-1)}[s] + \text{RTE}(\Delta T(t, s)) H^(l−1)[s]=H(l−1)[s]+RTE(ΔT(t,s))
参考
- https://blog.csdn.net/byn12345/article/details/105081338
- https://zhang-each.github.io/2021/09/09/reading15/#post-comment