【GNN】第三章 消息传递范式与PyG的MessagePassing基类

本文参考自datawhale2021.6学习:图神经网络

【GNN】第一章 图论基础

【GNN】第二章 PyG中的图与图数据集

1 消息传递范式

消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式,它将卷积算子推广到了不规则数据领域,实现了图与神经网络的连接

消息传递神经网络(MPNN)是一种框架,其前向传递有两个阶段:消息传递阶段(Message Passing)、读出阶段(Readout),这里先介绍消息传递阶段

1.1 消息传递的三个函数

  • 三个函数分为:
    • 各边要传递的消息的创建 ϕ \phi ϕ、消息聚合 □ \square 、节点表征的更新 γ \gamma γ 三个步骤
  • 对三个函数的要求:
    • 要求上述三个函数均可微
    • 且消息聚合具有排列不变性(函数输出结果与输入参数的排列无关,即对节点的排列不敏感)。
    • 具有排列不变性的函数有和函数、均值函数和最大值函数。
  • 消息传递的数学描述:
    • x i ( k − 1 ) ∈ R F \mathbf{x}^{(k-1)}_i\in\mathbb{R}^F xi(k1)RF表示 ( k − 1 ) (k-1) (k1)层中节点 i i i的节点属性, e j , i ∈ R D \mathbf{e}_{j,i} \in \mathbb{R}^D ej,iRD 表示从节点 j j j到节点 i i i的边的属性,消息传递可以描述为:
      x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right) xi(k)=γ(k)(xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i))

1.2 节点嵌入与节点表征

  • 节点嵌入(Node Embedding):神经网络生成节点表征的操作,或节点表征也称节点嵌入
  • 这里节点嵌入仅指代前者
  • 好的节点表征可以衡量节点间的相似性,需要通过图神经网络训练得到

2 PyG的MessagePassing基类

2.1 属性

源码

class MessagePassing(torch.nn.Module):
	def __init__(self, aggr: Optional[str] = "add", flow: str = "source_to_target", node_dim: int = -2):
        super(MessagePassing, self).__init__()
        
        # 此处省略n行代码
        
        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max',None]

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim
        
        self.fuse = self.inspector.implements('message_and_aggregate')
        
        # 此处省略n行代码
  • aggr:定义要使用的聚合方案,默认add
  • flow:定义消息传递的流向,从而确定给某节点传递消息的边的集合,默认s→t
    • i i i 表示目标节点, j j j 表示邻接节点
    • flow=‘source_to_target’ :target表入,即传递信息的边的集合为 ( j , i ) ∈ E (j,i)\in\mathcal{E} (j,i)E
    • flow=‘target_to_source’ :target表出,即传递信息的边的集合为 ( i , j ) ∈ E (i,j)\in\mathcal{E} (i,j)E
  • node_dim:定义scatter沿着哪个轴线传播,默认-2
  • fuse:检查是否实现了message_and_aggregate()方法,不需要自己定义

2.2 方法

class MessagePassing(torch.nn.Module):
    # 此处省略n行代码
    	self.fuse = self.inspector.implements('message_and_aggregate')
    	self.node_dim = node_dim
    
    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
    	# 此处省略n行代码
    	
        # 检查edge_index是否SparseTensor类型
        # 检查是否实现了message_and_aggregate()方法,是就执行该方法,再执行update方法
        if (isinstance(edge_index, SparseTensor) and self.fuse and not self.__explain__):
            coll_dict = self.__collect__(self.__fused_user_args__, edge_index, size, kwargs)
            
            # message_and_aggregate
            msg_aggr_kwargs = self.inspector.distribute('message_and_aggregate', coll_dict)
            out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)
            
			# update
            update_kwargs = self.inspector.distribute('update', coll_dict)
            return self.update(out, **update_kwargs)

        # 上述检查不通过,依次执行message(),aggregate(),update()方法
        elif isinstance(edge_index, Tensor) or not self.fuse:
        
            coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)
            
            # message
            msg_kwargs = self.inspector.distribute('message', coll_dict)
            out = self.message(**msg_kwargs)
            
            # aggregate
            aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
            out = self.aggregate(out, **aggr_kwargs)

 			# update
            update_kwargs = self.inspector.distribute('update', coll_dict)
            return self.update(out, **update_kwargs)


    def message(self, x_j): 
    	# 按需要覆写或不写
        return x_j


    def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor:
        
        # 按需要覆写或不写
                  
        if ptr is not None:
            ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
            return segment_csr(inputs, ptr, reduce=self.aggr)
            
        else:
            return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)


    def update(self, inputs): 
        # 按需要覆写或不写
        return inputs
    
    def message_and_aggregate(self, adj_t, x, norm):
        # 按需要覆写或不写
        return x
  • propagate(edge_index, size=None, **kwargs)
    • 调用以传递消息,在此方法中messageaggregateupdate等方法被调用
    • 若检测到message_and_aggregate和edge_index为SparseTensor,则即使messageaggregate存在也不调用,而是调用message_and_aggregate
    • 可将节点属性拆分成中心节点和邻接节点,对拆分的数据有格式要求,必须为 [num_nodes, *]。拆分如属性 x i x_i xi 和邻接节点属性 x j x_j xj,度 d e g i 、 d e g j deg_i、deg_j degidegj
    • size=None默认邻接矩阵对称,若是非对称的邻接矩阵(如二部图)则要传递参数 size=(N,M)
  • message (写入需要的参数…)
    • 实现 ϕ \phi ϕ 函数,创建各边要传递的邻接节点消息
    • 可以接收传递给propagate方法的任何参数,只要在其中进行定义。如def message(self,x_j)而非 def message(self,x_j=x_j)
  • aggregate (inputs, …)
    • 实现消息聚合
    • 关于scatter(src,index,dim=-1,out,dim_size,reduce=‘sum’):按照dim的操作方向, 将src的元素加到index指示的位置去。参考torch_scatter.scattertorch_scatter.scatter 区别scatter_
    • propagate调用时,传入给inputs的是message的输出
  • message_and_aggregate (写入需要的参数…)
    • 一些场景里 ϕ \phi ϕ 和聚合可以融合在一起操作,就可以在该方法里定义这两项操作,使程序运行更加高效
  • update (inputs, …)
    • 节点表征的更新,可以接收传递给propagate方法的任何参数
    • propagate 调用时inputs输入的是aggregate的输出

3 GCNConv的实现

3.1 数学定义

x i k = ∑ j ∈ N ( i ) ∪ { i } 1 d ( v i ) ⋅ d ( v j ) ⋅ ( Θ ⋅ x i k − 1 ) x_i^k = \sum_{j_\in \mathcal{N}(i) \cup \{i\}}\frac{1}{\sqrt{d(v_i)}\cdot \sqrt{d(v_j)}}\cdot \left( \Theta\cdot x_i^{k-1}\right) xik=jN(i){i}d(vi) d(vj) 1(Θxik1)

X k = D ^ − 1 2 A ^ D ^ − 1 2 Θ X k − 1 X^k = \hat D^{-\frac{1}{2}}\hat A\hat D^{-\frac{1}{2}}\Theta X^{k-1} Xk=D^21A^D^21ΘXk1
A ^ = A + I \hat A=A+I A^=A+I 加入了自循环的邻接矩阵, D ^ \hat D D^ 是由 A ^ \hat A A^ 计算的度矩阵
在这里插入图片描述
矩阵A行表出,列表入。左矩阵D对对应的出节点×,右矩阵D对对应的入节点×

3.2 代码实现

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):

    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')

        # 线性变换层 Θ
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
    
        # x 形状 [N, in_channels]
        # edge_index 形状 [2, E]

        # 添加自环的边
        # edge_index形状为 [2,E+N]
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # 节点属性做线性变换
        x = self.lin(x)

        # Compute normalization.
        row, col = edge_index   # row从节点0开始一直顺序排 e.g.[0,0,0,1,1…]
        deg = degree(col, x.size(0), dtype=x.dtype)   # 计算度矩阵
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # 若要将edge_index改写为SparseTensor
        # adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))

        # 调用propagate传递信息
        return self.propagate(edge_index, x=x, norm=norm, deg=deg.view((-1, 1))) # 若一个数据可以被拆分成属于中心节点的部分和属于邻接节点的部分,其形状必须是 [num_nodes, *],所以需要将deg的形状进行变换
        # return self.propagate(edge_index, x=x, norm=norm)
        # return self.propagate(adjmat, x=x, norm=norm, deg=deg.view((-1, 1)))

    # 覆写消息构建函数 Φ
    def message(self, x_j, norm, deg_i):
        # x_j 是邻接节点矩阵,形状为 [E+N, out_channels]
        # 这里flow = 'source_to_target',因此x_j行排序如row
        # deg_i 是col排序的点的度
        return norm.view(-1, 1) * x_j   # 将每个邻接节点正则化,返回形状同 x_j


    # 不需要覆写aggregate和update
    # 这里未实现message_and_aggregate

    # 也可以覆写aggregate,举个例子
    def aggregate(self, inputs, index, ptr, dim_size):
        return  super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size) 
        # index是中心节点,根据flow在此其排序同col
        # dim_size = 节点数

	# 覆写函数时,传入的参数不要写 y=y 这种格式


# 调用网络
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='dataset', name='Cora')
data = dataset[0]

net = GCNConv(data.num_features, 64) # 类属性的定义
h_nodes = net(data.x, data.edge_index) # 调用forward,输入参数
print(h_nodes.shape)

无注释代码

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):

        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        x = self.lin(x)

        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, norm=norm, deg=deg.view((-1, 1)))

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j 

    def aggregate(self, inputs, index, ptr, dim_size):
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)


from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='dataset', name='Cora')
data = dataset[0]

net = GCNConv(data.num_features, 64) # 类属性的定义
h_nodes = net(data.x, data.edge_index) # 调用forward,输入参数
print(h_nodes.shape)

4 作业

  1. 请总结MessagePassing基类的运行流程。
  • 首先创建每条边上要传递的邻接节点的信息
  • 其次对中心节点接收到的消息进行聚合
  • 最后更新节点表征
  1. 请复现一个一层的图神经网络的构造,总结通过继承MessagePassing基类来构造自己的图神经网络类的规范。
  • GNN规范:
    • 属性如神经网络层、继承flowaggr等属性
    • 定义forward方法,传入节点矩阵x与边edge_index
      • 添加自环边
      • 节点属性变换
      • 建立度矩阵,计算正则公式
      • (上述两步也可在message中完成)
      • 调用propagate,传入参数edge_index、方法message、aggregate、update要用到的参数如normxdeg
      • 返回最终update后的节点表征
    • 覆写有关函数
      • message传入邻接节点信息x_j,正则化公式norm
  • 复现一个一层图神经网络
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):

        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        x = self.lin(x)

        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, norm=norm, deg=deg.view((-1, 1)))

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j 
    
    def update(self,aggr_output):
        return F.relu(aggr_output)
  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值