Implementing the Two GNN Computation Paradigms in Python

Overview

In a previous post (link), I mentioned that graph neural networks can be implemented under two computation paradigms: a matrix computation paradigm and a node (vertex-centric) computation paradigm. This post demonstrates how to program both in Python.

Node Computation Paradigm: Overall Architecture

Under the node computation paradigm, a graph neural network essentially walks through the nodes one by one and computes each node's neighbor aggregation in turn; this is also the CUDA computation style that NVIDIA provides. In short, it evaluates the following formula:
$$h_v^{(l+1)} = \phi_{\mathrm{node}}\Big(\sum_{u \in \mathcal{N}(v)} \phi_{\mathrm{edge}}\big(h_u^{(l)},\, h_v^{(l)},\, e_{uv}\big)\Big)$$

where $\mathcal{N}(v)$ is the neighbor set of node $v$ and $e_{uv}$ is the weight of edge $(u, v)$.
In a distributed setting, this paradigm supports node communication across subgraphs.
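
Before the library-based implementation below, here is a minimal plain-Python sketch of that vertex-centric loop (the adjacency list neighbors and the feature array h are made up purely for illustration):

import numpy as np

def aggregate(h, neighbors):
    # Visit the nodes one by one and sum each node's in-neighbor features --
    # the per-vertex loop that this paradigm parallelizes across nodes.
    h_new = np.zeros_like(h)
    for v in range(len(h)):
        for u in neighbors[v]:
            h_new[v] += h[u]
    return h_new

h = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # 3 nodes, 2 features each
neighbors = {0: [1, 2], 1: [0], 2: [0, 1]}
print(aggregate(h, neighbors))  # [[1. 2.] [1. 0.] [1. 1.]]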

Node Computation Paradigm: Technical Details

The core idea is to convert the input adjacency matrix to a sparse representation in advance:

from torch_geometric.utils import dense_to_sparse
edge_src_target, edge_weight = dense_to_sparse(adj)  # adj is the dense adjacency matrix
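
For a concrete picture of what dense_to_sparse returns, here is a small hand-made adjacency matrix (values chosen purely for illustration):

import torch
from torch_geometric.utils import dense_to_sparse

adj = torch.tensor([[0.0, 0.5, 0.0],
                    [0.5, 0.0, 1.0],
                    [0.0, 1.0, 0.0]])
edge_src_target, edge_weight = dense_to_sparse(adj)
print(edge_src_target)  # tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) -- COO indices, shape (2, edges_num)
print(edge_weight)  # tensor([0.5000, 0.5000, 1.0000, 1.0000]) -- the non-zero entries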

Once the sparse edge index and edge weights are obtained, the computation proceeds as follows:

import torch
import torch.nn as nn
from torch.nn import Parameter, Sequential, Linear, Sigmoid
from torch_scatter import scatter_add

class GraphGNN(nn.Module):
    def __init__(self, device, in_dim, out_dim):
        super(GraphGNN, self).__init__()
        self.device = device
        self.w = Parameter(torch.rand([1]))  # scalar parameters w and b
        self.b = Parameter(torch.rand([1]))
        e_h = 32  # edge_mlp hidden size
        e_out = 30  # edge_mlp output size
        n_out = out_dim  # node_mlp output size
        self.edge_mlp = Sequential(Linear(in_dim * 2 + 1, e_h),
                                   Sigmoid(),
                                   Linear(e_h, e_out),
                                   Sigmoid(),
                                   )
        self.node_mlp = Sequential(Linear(e_out, n_out),
                                   Sigmoid(),
                                   )

    def forward(self, x, edge_src_target, edge_weight):  # x: (batch_size=32, station_num=2160, attr_num=3)
        edge_src_target = edge_src_target.to(self.device)  # node index pairs, shape (2, edges_num); move all inputs to the device
        edge_weight = edge_weight.to(self.device)
        self.w = self.w.to(self.device)
        self.b = self.b.to(self.device)

        edge_src, edge_target = edge_src_target  # (2, edges_num) -> src (edges_num,) and target (edges_num,)
        node_src = x[:, edge_src]  # (batch_size, station_num, feature_num) -> (batch_size, edges_num, feature_num)
        node_target = x[:, edge_target]  # (batch_size, station_num, feature_num) -> (batch_size, edges_num, feature_num)

        edge_w = edge_weight.unsqueeze(-1)
        edge_w = edge_w[None, :, :].repeat(node_src.size(0), 1, 1).to(self.device)  # (edges_num, 1) -> (32, edges_num, 1)
        out = torch.cat([node_src, node_target, edge_w], dim=-1)  # concatenate along the last dimension -> (32, edges_num, 3+3+1=7)
        out = self.edge_mlp(out)  # update edge attributes with edge_mlp: (32, edges_num, 30), e_out = 30

        # Aggregate incoming edge features and subtract outgoing edge features
        # to obtain each node's updated feature.
        out_add = scatter_add(out, edge_target, dim=1, dim_size=x.size(1))
        # out_sub = scatter_sub(out, edge_src, dim=1, dim_size=x.size(1))
        out_sub = scatter_add(out.neg(), edge_src, dim=1, dim_size=x.size(1))  # scatter_sub equivalent for newer torch_scatter versions
        out = out_add + out_sub
        out = self.node_mlp(out)  # map the aggregated edge features to node outputs

        return out  # y_hat
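
To see what the scatter step does, and to smoke-test the class end to end, here is a small example (all sizes are made up for illustration; it requires torch_scatter and torch_geometric to be installed):

# How scatter_add sums edge features onto their target nodes:
src = torch.tensor([[[1.0], [2.0], [3.0]]])  # (batch=1, edges=3, feat=1)
index = torch.tensor([0, 0, 1])  # target node of each edge
print(scatter_add(src, index, dim=1, dim_size=2))  # tensor([[[3.], [3.]]]): edges 0,1 -> node 0; edge 2 -> node 1

# End-to-end smoke test: 4 nodes, 3 features, batch of 2.
from torch_geometric.utils import dense_to_sparse
device = torch.device('cpu')
model = GraphGNN(device, in_dim=3, out_dim=3)
edge_src_target, edge_weight = dense_to_sparse(torch.rand(4, 4))
x = torch.rand(2, 4, 3)  # (batch_size, station_num, attr_num)
y_hat = model(x, edge_src_target, edge_weight)
print(y_hat.shape)  # torch.Size([2, 4, 3])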

Matrix Computation Paradigm: Overall Architecture

Under the matrix computation paradigm, a graph neural network treats the graph's edges as an adjacency matrix, the graph's nodes as an attribute matrix, and the parameters applied to each neighbor as a weight matrix. In short, it evaluates the following formula:
$$H^{(l+1)} = \sigma\big(A\, H^{(l)}\, W^{(l)}\big)$$

where $A$ is the adjacency matrix, $H^{(l)}$ the node attribute matrix, and $W^{(l)}$ the weight matrix.
In a distributed setting, this paradigm supports tensor-parallel execution.
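
In its simplest form this is just two dense matrix multiplications. A minimal sketch in PyTorch (shapes made up for illustration; the adjacency matrix is left unnormalized here):

import torch

N, F_in, F_out = 5, 3, 4
A = torch.rand(N, N)  # adjacency matrix
X = torch.rand(N, F_in)  # node attribute matrix
W = torch.rand(F_in, F_out)  # weight matrix
H = torch.sigmoid(A @ X @ W)  # H' = sigma(A X W)
print(H.shape)  # torch.Size([5, 4])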

Matrix Computation Paradigm: Technical Details

Here a class defines a single GCN layer. The core logic lives in the forward function, where input is the node attribute matrix and adj is the adjacency matrix:

import torch
import torch.nn as nn
class GCNLayer(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=False, batch_norm=False):
        super(GCNLayer, self).__init__()
        self.weight = torch.Tensor(in_features, out_features)
        self.weight = nn.Parameter(nn.init.xavier_uniform_(self.weight))  # randomly initialized weight matrix

        # The bias is not used in this post; note that nn.init.xavier_uniform_
        # requires a 2-D tensor, so the 1-D bias is zero-initialized here instead
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

        self.bn = nn.BatchNorm1d(out_features) if batch_norm else None

    def forward(self, input, adj, batch_norm=True):  # input: raw feature matrix X, adj: current adjacency matrix cur_adj
        support = torch.matmul(input, self.weight)  # X × W, e.g. (2160, 3) × (3, 3) = (2160, 3)
        output = torch.matmul(adj, support)  # cur_adj × X × W, e.g. (2160, 2160) × (2160, 3) = (2160, 3)

        if self.bias is not None:
            output = output + self.bias

        if self.bn is not None and batch_norm:
            output = self.compute_bn(output)

        return output

    def compute_bn(self, x):
        if len(x.shape) == 2:
            return self.bn(x)
        else:
            return self.bn(x.view(-1, x.size(-1))).view(x.size())
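
A quick usage check of the layer (sizes made up; the 2160-node shapes in the inline comments come from the original example):

layer = GCNLayer(in_features=3, out_features=3, batch_norm=True)
X = torch.rand(6, 3)  # node attribute matrix: 6 nodes, 3 features
adj = torch.rand(6, 6)  # adjacency matrix
out = layer(X, adj)
print(out.shape)  # torch.Size([6, 3])

Note that the layer multiplies by adj exactly as given; the GCN paper linked in the docstring actually uses the symmetrically normalized adjacency $\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ (with self-loops added), which the caller would precompute.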