Task02 消息传递图神经网络

Task02 消息传递图神经网络

一、消息传递范式基本概念

  • 消息传递范式定义
    • 基于神经网络的生成节点表征的范式
    • 是一种聚合邻接节点信息来更新中心节点信息的范式
  • 消息传递范式的三个步骤
    • (1)邻接节点信息变换
    • (2)邻接节点信息聚合到中心节点
    • (3)聚合信息变换
  • 消息传递范式描述

x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right) xi(k)=γ(k)(xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i))

  • 节点嵌入:指代神经网络生成节点表征的操作

二、MessagePassing基类

Pytorch Geometric(PyG)提供了MessagePassing基类,它实现了消息传播的自动处理,继承该基类可使我们方便地构造消息传递图神经网络,我们只需定义函数 ϕ \phi ϕ,即message()函数,和函数 γ \gamma γ,即update()函数,以及使用的消息聚合方案,即aggr="add"aggr="mean"aggr="max"

三、继承MessagePassing类的GCNConv

  • GCNConv的数学定义为

    x i ( k ) = ∑ j ∈ N ( i ) ∪ { i } 1 deg ⁡ ( i ) ⋅ deg ⁡ ( j ) ⋅ ( Θ ⋅ x j ( k − 1 ) ) \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right) xi(k)=jN(i){i}deg(i) deg(j) 1(Θxj(k1))

  • 此公式的步骤:

    • 向邻接矩阵添加自环
    • 线性转换节点特征矩阵
    • 计算归一化系数
    • 归一化 j j j中的节点特征
    • 将相邻节点特征相加("求和 "聚合)

四、propagate函数

def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
    r"""开始消息传播的初始调用。
    Args:
        edge_index (Tensor or SparseTensor): 定义了消息传播流。
        	当flow="source_to_target"时,节点`edge_index[0]`的信息将被发送到节点`edge_index[1]`,
        	反之当flow="target_to_source"时,节点`edge_index[1]`的信息将被发送到节点`edge_index[0]`
        kwargs: 图其他属性或额外的数据。
    """

五、覆写message函数

class GCNConv(MessagePassing):
    def forward(self, x, edge_index):
        # ....
        return self.propagate(edge_index, x=x, norm=norm, d=d)
    def message(self, x_j, norm, d_i):
        # x_j has shape [E, out_channels]
        return norm.view(-1, 1) * x_j * d_i # 这里不管正确性     

六、覆写aggregate函数

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        
    def forward(self, x, edge_index):
        # ....
        return self.propagate(edge_index, x=x, norm=norm, d=d)

    def aggregate(self, inputs, index, ptr, dim_size):
        print(self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
        

七、覆写message_and_aggregate函数

from torch_sparse import SparseTensor

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        
    def forward(self, x, edge_index):
        # ....
        adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))
        # 此处传的不再是edge_idex,而是SparseTensor类型的Adjancency Matrix
        return self.propagate(adjmat, x=x, norm=norm, d=d)
    
    def message(self, x_j, norm, d_i):
        # x_j has shape [E, out_channels]
        return norm.view(-1, 1) * x_j * d_i # 这里不管正确性
    
    def aggregate(self, inputs, index, ptr, dim_size):
        print(self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
    
    def message_and_aggregate(self, adj_t, x, norm):
        print('`message_and_aggregate` is called')

九、覆写update函数

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')

    def update(self, inputs: Tensor) -> Tensor:
        return inputs

十、作业

作业描述:

  1. 请总结MessagePassing类的运行流程以及继承MessagePassing类的规范。
  2. 请继承MessagePassing类来自定义以下的图神经网络类,并进行测试:
    1. 第一个类,覆写message函数,要求该函数接收消息传递源节点属性x、目标节点度d
    2. 第二个类,在第一个类的基础上,再覆写aggregate函数,要求不能调用super类的aggregate函数,并且不能直接复制super类的aggregate函数内容。
    3. 第三个类,在第二个类的基础上,再覆写update函数,要求对节点信息做一层线性变换。
    4. 第四个类,在第三个类的基础上,再覆写message_and_aggregate函数,要求在这一个函数中实现前面message函数和aggregate函数的功能。

题解:

题解1.MessagePassing基类的运行流程:

  1. 初始化参数聚合函数aggr,消息传递流向flow,传播维度node_dim

  2. 初始化自实现函数中用到的自定义参数__user_args____fused_user_args__

  3. 基于Module基类,调用forward函数,用于数据或参数的初始化

  4. propagate函数:
    (1)检查edge_indexsize参数是否符合要求,并返回size
    (2)判断edge_index是否为SparseTensor,如果满足,则执行message_and_aggregate,再执行update方法
    (3)如果不满足,就先执行message方法,再执行aggregateupdate方法

题解2.

import torch
from torch.nn import functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.datasets import Planetoid


class MyGNN(MessagePassing):
    """
    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{x}_i \cdot \mathbf{\Theta}_1 +
        \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot
        (\mathbf{\Theta}_2 \mathbf{x}_i - \mathbf{\Theta}_3 \mathbf{x}_j)
    """

    def __init__(self, in_channels, out_channels, device):
        super(MyGNN, self).__init__(aggr='add')
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.lin1 = torch.nn.Linear(in_channels, out_channels).to(device)
        self.lin2 = torch.nn.Linear(in_channels, out_channels).to(device)
        self.lin3 = torch.nn.Linear(in_channels, out_channels).to(device)

    def forward(self, x, edge_index):
        a = self.lin1(x)
        b = self.lin2(x)
        out = self.propagate(edge_index, a=a, b=b)
        return self.lin3(x) + out

    def message(self, a_i, b_j):
        out = a_i - b_j
        return out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
device = torch.device('cuda:0')

dataset = Planetoid(root='dataset/Cora', name='Cora')
model = MyGNN(in_channels=dataset.num_features, out_channels=dataset.num_classes, device=device)
print(model)

data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index).to(device)
    pred = out.argmax(dim=1)
    accuracy = int((pred[data.test_mask] == data.y[data.test_mask]).sum()) / data.test_mask.sum()
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
    if epoch % 10 == 0:
        print("Train Epoch: {:3} Accuracy: {:.2f}%".format(epoch, accuracy.item() * 100.0))

参考资料

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值