MTGCN笔记
标题
《Multi-Track Message Passing: Tackling Oversmoothing and Oversquashing in Graph Learning via Preventing Heterophily Mixing》
摘要
- 图神经网络的进一步发展受到消息传递中固有的两个问题的阻碍,即过平滑(oversmoothing)和过度压缩(oversquashing)。
- 作者认为这些问题的根本原因在于聚合过程中由于异配性混合导致的信息丢失,其中具有不同类别的消息被混合在一起。
- 为此,作者提出了一种**多轨图卷积网络(MTGCN)**来有效解决过平滑和过度压缩问题。MTGCN的基本思想是:如果根据消息的类别将它们分开并独立传播,则可以防止异配性混合。

方法
消息轨道的定义 : 对于一个图来说,消息轨道定义为镜像了图的拓扑结构的一系列同构图,其中每个消息轨道都是专门用以传递某个类型的节点的消息。
本文提出的MTGCN可以归纳为三个关键步骤:
- **加载 **:所有节点的特征作为初始消息被加载到对应的消息轨道,属于同一类的节点应该与对应的轨道联系在一起,这通过一个节点轨道归属度矩阵 F ∈ { 0 , 1 } ∣ T ∣ × ∣ V ∣ \mathbf{F} \in \{0,1\}^{|{\mathcal{T}|\times|\mathcal{V}|}} F∈{0,1}∣T∣×∣V∣。其中每个 F T , v \mathbf{F}_{T,v} FT,v指导节点的特征是否会被加载到轨道 T T T。
- 多轨道消息传递(MTMP):初始消息通过在多次迭代中在各自的轨道中传播和聚合来更新。
- 捕获 :基于归属度矩阵 F \mathbf{F} F ,节点通过它们归属的轨道获得的消息构建他们的节点表示 Z \mathbf{Z} Z。
多轨消息传递
在MTMP中,我们将所有轨道中所有节点的消息建模为一个三阶张量 M ∈ R ∣ T ∣ × ∣ V ∣ × d M \in \mathbb{R}^{|\mathcal{T}| \times |V| \times d} M∈R∣T∣×∣V∣×d,其中 d d d 是每条消息的维度。具体来说,矩阵 M T , : , : ∈ R ∣ V ∣ × d M_{T,:,:} \in \mathbb{R}^{|V| \times d} MT,:,:∈R∣V∣×d 是 M M M 的一个切片表示轨道 T T T 中的所有消息,而向量 M T , v , : ∈ R d M_{T,v,:} \in \mathbb{R}^d MT,v,:∈Rd 表示轨道 T T T 中节点 v v v 的消息。
最初的消息由加载的对应轨道的节点特征 X X X 构建。对于每个节点 v ∈ V v \in V v∈V 和 轨道 T ∈ T T \in \mathcal{T} T∈T:
M T , v , : ( 0 ) = g ( X v , : ) if F T , v = 1 M_{T,v,:}^{(0)} = g(\mathbf{X}_{v,:}) \quad \text{if} \quad \mathbf{F}_{T,v} = 1 MT,v,:(0)=g(Xv,:)ifFT,v=1
M T , v , : ( 0 ) = 0 ⃗ if F T , v = 0 M_{T,v,:}^{(0)} = \vec{0} \quad \text{if} \quad \mathbf{F}_{T,v} = 0 MT,v,:(0)=0ifFT,v=0
加载后,多轨消息传递通过 L 层更新消息,即 M ( 0 ) → M ( L ) M^{(0)}→M^{(L)} M(0)→M(L)。 具体来说,在每个第 l l l层中,不同轨道 T ∈ T T\in \mathcal{T} T∈T中的消息被独立地传播和聚合。对于每个轨道T,消息传递定义为:
M T , : , : ( l ) = ( D ~ − 1 / 2 A ~ D ~ − 1 / 2 ) M T , : , : ( l − 1 ) + α l M T , : , : ( 0 ) M _ { T , : , :} ^ { ( l ) } = ( \tilde { D } ^ { - 1 / 2 } \tilde { A } \tilde { D } ^ { - 1 / 2 } ) M _ { T , : , :} ^ { ( l - 1 ) } + \alpha _ { l } M _ { T , : , :} ^ { ( 0 ) } MT,:,:(l)=(D~−1/2A~D~−1/2)MT,:,:(l−1)+αlMT,:,:(0)
在消息传递中,本文使用两种策略促进更深层次的图模型:
- 将残差连接融合到消息中
- 省略了每层的可学习参数矩阵和激活函数
- 典型GCN的传播公式 H ( l + 1 ) = σ ( D ~ − 1 2 A ~ D ~ − 1 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)}) H(l+1)=σ(D~−21A~D~−21H(l)W(l))
一旦完成L层多轨消息传递,节点在其对应轨道中获取更新的消息 M ( L ) M^{(L)} M(L)以构造节点表示 Z \mathbf{Z} Z。对于每个节点
Z v , : = ( ∑ T ∈ T F T , v ⋅ M T , v , : ( L ) ) W Z Z _ { v ,: } = ( \sum _ { T \in \mathcal{T} } F _ { T,v} \cdot M _ { T , v , : } ^ { ( L ) } ) W _ { Z } Zv,:=(T∈T∑FT,v⋅MT,v,:(L))WZ
节点-轨道归属度
正如前面分析的那样,节点-轨道隶属矩阵 F 在我们的模型中起着至关重要的作用。 我们利用点积注意力来获得 F。具体来说,节点 v 的从属关系由下式给出
F : , v = softmax ( H v , : W K ( P W Q ) T ) F _ { :, v } = \text{softmax} ( H _ { v , :} W _ { K } ( P W _ { Q } ) ^ { T } ) F:,v=softmax(Hv,:WK(PWQ)T)
- $ F _ { :, v } \in \mathbb{R}^{|\mathcal{T}|}$ 代表节点v对所有轨道的归属度
- P P P 有 ∣ T ∣ |\mathcal{T}| ∣T∣行,代表轨道原型
- 计算得到的注意力分数代表辅助表示 H H H和轨道原型 P P P之间的相似程度,即对于节点对不同轨道的归属度。
节点的辅助表示 H \mathbf{H} H由辅助模型 Φ \Phi Φ生成。 H \mathbf{H} H 在计算 F \mathbf{F} F 时的主要目的是包含节点与其自我图中的不同信息,不追求高准确度,因此即使是简单的GCN也可以在理论上获得提升。
轨道原型 P \mathbf{P} P 是使用代表性节点构建的。 具体来说,轨道T的原型定义为
P T , : = 1 Δ ∑ v ∈ B δ ( y v , T ) ⋅ H v , : P _ { T , : } = \frac { 1 } { \Delta } \sum _ { v \in \mathcal{B} } \delta ( y _ { v } , T ) \cdot H _ { v , : } PT,:=Δ1v∈B∑δ(yv,T)⋅Hv,:
- B \mathcal{B} B 是代表性节点的集合,它是由训练集中带标签的节点和更有可能被估计模型正确预测的节点(使用softmax置信分数衡量)
- δ \delta δ 函数指示节点标签是否与对应的轨道相同
是代表性节点的集合,它是由训练集中带标签的节点和更有可能被估计模型正确预测的节点(使用softmax置信分数衡量)
- δ \delta δ 函数指示节点标签是否与对应的轨道相同

2353

被折叠的 条评论
为什么被折叠?



