(DataWhale)图神经网络Task06:基于图神经网络的图表征学习

背景

Weisfeiler-Lehman Test (WL Test)

Weisfeiler-Lehman图的同构性测试算法,简称WL Test,是一种用于测试两个图是否同构的算法。

  • 两个图是同构的,意思是两个图拥有一样的拓扑结构,也就是说,可以通过重新标记节点从一个图转换到另外一个图。

WL Test过程: L u h ← hash ⁡ ( L u h − 1 + ∑ v ∈ N ( U ) L v h − 1 ) L^{h}_{u} \leftarrow \operatorname{hash}\left(L^{h-1}_{u} + \sum_{v \in \mathcal{N}(U)} L^{h-1}_{v}\right) Luhhash(Luh1+vN(U)Lvh1)

  1. 迭代地聚合节点及其邻接节点的标签;
  2. 将聚合的标签散列(hash)成新标签。

在迭代过程中,发现两个图之间的节点的标签不同时,就可以确定这两个图是非同构的。(注意:节点标签可能的取值只能是有限个数)

Weisfeiler-Leman Test 算法实例

给定两个图 G G G G ′ G^{\prime} G,每个节点拥有标签(对于没有节点标签的图,使用节点的度作为标签)。

WL Test 算法通过重复执行以下给节点打标签的过程来判断图是否同构:

  1. 聚合自身与邻接节点的标签得到一串字符串,自身标签与邻接节点的标签中间用","分隔,邻接节点的标签按升序排序。
  1. 标签散列,即标签压缩,将较长的字符串映射到一个简短的标签。
image-20210602143932099
  1. 给节点重新打上标签。

当出现两个图相同节点标签的出现次数不一致时,即可判断两个图不相似。如果上述的步骤重复一定的次数后,没有发现有相同节点标签的出现次数不一致的情况,那么我们无法判断两个图是否同构。

当两个节点的 h h h层的标签一样时,表示分别以这两个节点为根节点的WL子树是一致的。WL子树与普通子树不同,WL子树包含重复的节点。下图展示了一棵以1节点为根节点高为2的WL子树。

WL Subtree Kernel:图相似性衡量

WL Test 算法的局限:只能判断两个图的相似性,无法衡量图之间的相似性。

WL Subtree Kernel使用WL Test算法不同迭代中的各类节点标签出现的次数,将其存于一个向量作为图的表征。直观地说,在WL Test的第 k k k次迭代中,一个节点的标签代表了以该节点为根的高度为 k k k的子树结构。两个图的表征向量的内积,即可作为这两个图的相似性估计,内积越大表示相似性越高。

image-20210602145242299

图同构网络

图表征学习方法简介

图表征学习要求根据节点属性、边和边的属性(如果有的话)生成一个向量作为图的表征。图表征学习主要包含以下两个过程:

  1. 首先计算得到节点表征;
  2. 其次对图上各个节点的表征做图池化(Graph Pooling),或称为图读出(Graph Readout),得到图的表征(Graph Representation)。

图同构网络基本思路

WL Test是图神经网络表达能力的上限。为了得到和WL Test一样强大的图神经网络,关键在于将图神经网络中的聚合函数设置为一个单射函数。

单射是指不同的输入值一定会得到不同的函数值。

节点表征

能实现判断图同构性的图神经网络需要满足:只在两个节点自身标签一样且它们的邻接节点一样时,图神经网络将这两个节点映射到相同的表征,即映射是单射性的。

可重复集合(Multisets)指的是元素可重复的集合,元素在集合中没有顺序关系。 一个节点的所有邻接节点是一个可重复集合,一个节点可以有重复的邻接节点,邻接节点没有顺序关系。因此GIN模型中生成节点表征的方法遵循WL Test算法更新节点标签的过程。

图池化

在生成节点的表征后仍需要执行图池化(或称为图读出)操作得到图表征,最简单的图读出操作是做求和。由于每一层的节点表征都可能是重要的,因此在图同构网络中,不同层的节点表征在求和后被拼接,其数学定义如下,
h G = CONCAT ( READOUT ( { h v ( k ) ∣ v ∈ G } ) ∣ k = 0 , 1 , ⋯   , K ) h_{G} = \text{CONCAT}(\text{READOUT}\left(\{h_{v}^{(k)}|v\in G\}\right)|k=0,1,\cdots, K) hG=CONCAT(READOUT({hv(k)vG})k=0,1,,K)
采用拼接而不是相加的原因在于不同层节点的表征属于不同的特征空间,这样得到的图的表示与WL Subtree Kernel得到的图的表征是等价的。

图同构网络的实现

图同构卷积层(GINConv)

在可重复集合上,求和函数是一个单射函数,因此只需将聚合函数改为求和函数,即可提升图神经网络的表达能力。图同构卷积层的数学定义如下:
x i ′ = h Θ ( ( 1 + ϵ ) ⋅ x i + ∑ j ∈ N ( i ) x j ) \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right) xi=hΘ(1+ϵ)xi+jN(i)xj
通过torch_geometric.nn.GINConv可直接使用PyG定义好的图同构卷积层,然而该实现不支持存在边属性的图,对于这类图,自定义一个支持边属性的GINConv模块,该模块继承MessagePassing类,遵循“消息传递、消息聚合、消息更新”过程。实现的关键在以下两点:

  • super(GINConv, self).__init__(aggr = "add")中定义了消息聚合方式为add,传入给任一个目标节点的所有消息被求和得到aggr_out
  • forward函数中执行out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))实现消息的更新。
import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder

### GIN卷积层实现
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        super(GINConv, self).__init__(aggr = "add")
        self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
        self.eps = nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr) # 先将类别型边属性转换为边表征
        out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)
        
    def update(self, aggr_out):
        return aggr_out
BondEncoderAtomEncoder
  • BondEncoder():将输入的类别型边属性转换为边表征;
  • AtomEncoder():将输入的类别型节点属性转换为节点表征(见下文节点嵌入模块中的使用)。

训练数据中节点和边的属性都为离散值,属于不同的空间,无法直接将它们融合在一起。通过嵌入,可以将节点属性和边属性分别映射到一个新的空间,在这个新的空间中,对节点和边进行信息融合。GINConvmessage()函数中的x_j + edge_attr 操作执行了节点信息和边信息的融合。

import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims 

full_atom_feature_dims = get_atom_feature_dims()  # 一个链表list,存储了节点属性向量每一维可能取值的数量,即`X[i]` 可能的取值一共有`full_atom_feature_dims[i]`种情况,`X`为节点属性
full_bond_feature_dims = get_bond_feature_dims()

class AtomEncoder(torch.nn.Module):
    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()
        self.atom_embedding_list = torch.nn.ModuleList()

        for i, dim in enumerate(full_atom_feature_dims):  # 节点属性有多少维,那么就需要有多少个嵌入函数
            emb = torch.nn.Embedding(dim, emb_dim)  # 实例化一个嵌入函数
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        x_embedding = 0
        for i in range(x.shape[1]):
            x_embedding += self.atom_embedding_list[i](x[:,i])  # 不同属性值得到的不同嵌入向量进行相加,从而将节点的的不同属性融合在一起

        return x_embedding


class BondEncoder(torch.nn.Module):
    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()
        self.bond_embedding_list = torch.nn.ModuleList()

        for i, dim in enumerate(full_bond_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            bond_embedding += self.bond_embedding_list[i](edge_attr[:,i])

        return bond_embedding   

基于图同构网络的节点嵌入模块(GINNodeEmbedding Module)

  • 输入到此节点嵌入模块的节点属性为类别型向量,首先用AtomEncoder对其做嵌入得到第0层节点表征;
  • 然后逐层GINConv计算节点表征,从第1层开始到第num_layers层,每一层节点表征的计算都以上一层的节点表征h_list[layer]、边edge_index和边的属性edge_attr为输入。
import torch
from mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F

# 节点嵌入网络构建
class GINNodeEmbedding(torch.nn.Module):
    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
        super(GINNodeEmbedding, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        # List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        h_list = [self.atom_encoder(x)]  # 先将类别型原子属性转化为原子表征
        for layer in range(self.num_layers):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # 最后一层不执行relu函数
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        if self.JK == "last":  # 取最后一层输出作为节点嵌入
            node_representation = h_list[-1]
        elif self.JK == "sum":  # 对各层输出求和作为节点嵌入
            node_representation = 0
            for layer in range(self.num_layers + 1):
                node_representation += h_list[layer]

        return node_representation

基于图同构网络的图表征模块(GINGraphRepr Module)

  • 首先采用GINNodeEmbedding模块对图上每一个节点做节点嵌入,得到节点表征;
  • 然后对节点表征做图池化得到图的表征;
  • 最后用一层线性变换对图表征转换为对图的预测。
import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_node import GINNodeEmbedding

class GINGraphRepr(nn.Module):

    def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
        """GIN Graph Pooling Module
        Args:
            num_tasks: 预测类别数,决定了图表征的维度
            num_layers: GINConv层数
            emb_dim: 节点嵌入维度
            residual (bool): 是否使用残差
            drop_ratio: dropout比率
            JK: 可选的值为"last"和"sum"。选"last",只取最后一层的结点的嵌入,选"sum"对各层的结点的嵌入求和
            graph_pooling: 图池化方法,可选"sum","mean","max","attention"和"set2set"

        Out:
            图表征
        """
        super(GINGraphPooling, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)

        # 图池化函数
        if graph_pooling == "sum":  # 对节点表征求和
            self.pool = global_add_pool
        elif graph_pooling == "mean":  # 对节点表征求平均
            self.pool = global_mean_pool
        elif graph_pooling == "max":  # 对一个batch中所有节点计算节点表征各个维度的最大值
            self.pool = global_max_pool
        elif graph_pooling == "attention":  # Attention对节点表征加权求和
            self.pool = GlobalAttention(gate_nn=nn.Sequential(
                nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
        elif graph_pooling == "set2set":  # 另一种Attention对节点表征加权求和
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)
        output = self.graph_pred_linear(h_graph)

        if self.training:
            return output
        else:
            # At inference time, relu is applied to output to ensure positivity
            # 因为预测目标的取值范围就在 (0, 50] 内
            return torch.clamp(output, min=0, max=50)

参考

  1. DataWhale GNN组队学习
  2. Global Pooling Layers
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值