Learning the DGL Framework from the Official Docs, Day 5: Custom GNN Modules (a GraphSAGE Implementation)

References

  1. https://docs.dgl.ai/guide/nn.html#guide-nn

If DGL does not provide the GNN module you need, you can define your own (this topic feels like it belongs later in the guide; it may scare off beginners like me). This section uses GraphSAGE as the example.

As in PyTorch, the constructor performs the following tasks:

  1. Set options
  2. Register learnable parameters or submodules
  3. Initialize parameters

The demo code is as follows:

import torch.nn as nn

from dgl.utils import expand_as_pair

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation

Since a node can serve both as a source node and as a destination node, the line "self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)" splits the input feature dimensionality into two parts: one used when the node acts as a source and one used when it acts as a destination.
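A quick sketch of what this splitting looks like (toy dimensions; without a graph argument the function simply duplicates a non-tuple input, as its source code near the end of this post shows):

import torch
from dgl.utils import expand_as_pair

# Homogeneous case: a single dimensionality is duplicated
in_src, in_dst = expand_as_pair(10)        # -> (10, 10)
# Bipartite case: a (src, dst) tuple passes through unchanged
in_src, in_dst = expand_as_pair((10, 5))   # -> (10, 5)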

Registering learnable parameters or submodules

"self._aggre_type" is the message aggregator type; common choices include "mean", "sum", "max", "min", and "lstm".

        # Aggregator types: mean, max_pool, lstm, gcn
        if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'max_pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type in ['mean', 'max_pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()
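For instance, once the full class below is assembled, instantiating it with the "max_pool" aggregator registers three submodules (a toy check; the repr shown as a comment is abbreviated):

conv = SAGEConv(in_feats=10, out_feats=16, aggregator_type='max_pool')
print(conv)
# SAGEConv(
#   (fc_pool): Linear(in_features=10, out_features=10, bias=True)
#   (fc_self): Linear(in_features=10, out_features=16, bias=True)
#   (fc_neigh): Linear(in_features=10, out_features=16, bias=True)
# )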

Initializing parameters

The "gcn" aggregator type needs separate treatment. This shows up in "if aggregator_type in ['mean', 'max_pool', 'lstm']" above, and again in the code that follows.

    def reset_parameters(self):
        """重新初始化可学习的参数"""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
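As a side note, "calculate_gain('relu')" just returns the recommended scaling factor for ReLU networks (√2), which Xavier initialization folds into the bounds of its uniform distribution. A minimal sketch with made-up layer sizes:

import torch.nn as nn

gain = nn.init.calculate_gain('relu')   # sqrt(2) ≈ 1.414
w = nn.Linear(10, 16).weight
# Samples from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out))
nn.init.xavier_uniform_(w, gain=gain)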

Next comes the forward() function:

Message generation and passing

The formula is

$$h_{N(dst)}^{(l+1)} = \mathrm{aggregate}\left(\left\{ h_{src}^{(l)},\ \forall src \in N(dst) \right\}\right)$$

Message generation and message passing are written together in "graph.update_all()".
Reading the code:

  1. For the "mean" aggregator, the neighbor message is the average of the features of the destination node's neighbors (see the toy check after the code below);
  2. For the "gcn" aggregator, the neighbor message is the sum of the neighbors' features plus the node's own features, divided by the in-degree plus one (the "plus one" counts the node itself);
  3. For the "max_pool" aggregator, the source features first go through a nonlinear transformation, and the element-wise maximum over the destination node's neighbors then becomes the neighbor message.
# These imports live at module level (the docs' snippet placed them inside forward())
import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape

    def forward(self, graph, feat):
        with graph.local_scope():
            # Expand the input features according to the graph type
            feat_src, feat_dst = expand_as_pair(feat, graph)

            if self._aggre_type == 'mean':
                graph.srcdata['h'] = feat_src
                graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
                # Divide by (in-degree + 1), counting the node itself
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
            elif self._aggre_type == 'max_pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
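As a toy check of what "update_all()" computes for the "mean" aggregator (a made-up three-node graph, not from the docs):

import dgl
import torch
import dgl.function as fn

g = dgl.graph(([0, 1, 2], [2, 2, 0]))            # edges 0->2, 1->2, 2->0
g.ndata['h'] = torch.tensor([[1.], [3.], [5.]])
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
print(g.ndata['neigh'])
# tensor([[5.], [0.], [2.]]): node 2 averages its neighbors (1+3)/2,
# node 0 receives 5 from node 2, and node 1 has no in-edges so it gets zeros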

Message aggregation

The formulas are

$$h_{dst}^{(l+1)} = \sigma\left(W \cdot \mathrm{concat}\left(h_{dst}^{(l)},\ h_{N(dst)}^{(l+1)}\right) + b\right)$$

$$h_{dst}^{(l+1)} = \mathrm{norm}\left(h_{dst}^{(l+1)}\right)$$

The "gcn" aggregator does not apply a separate linear transformation to the destination node's own features; only the neighbor message (which, as shown above, already mixes in the node's own features) goes through a linear transformation to become the new message.

        # h_self is the destination nodes' own features; the docs' snippet never defines it
        h_self = feat_dst
        # In GraphSAGE, the gcn aggregator does not need fc_self
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
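Note that the code computes "fc_self(h_self) + fc_neigh(h_neigh)" rather than a single linear layer over a concatenation, as the formula writes it. The two forms are equivalent: splitting $W$ by columns into blocks that act on each half of the concatenation gives

$$W \cdot \mathrm{concat}\left(h_{dst},\ h_{N(dst)}\right) = W_{self}\, h_{dst} + W_{neigh}\, h_{N(dst)}, \qquad W = \left[W_{self}\ \ W_{neigh}\right]$$

so the sum of two linear layers is just a reparameterization (with two bias vectors instead of one).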

This seems to be the first time node data is accessed via "graph.srcdata['h']" and "graph.dstdata['h']", which matches the earlier idea of splitting the input features into two parts. My guess is that it also distinguishes these from a feature of the same name accessed via "g.ndata['h']".
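On a homogeneous graph the three views refer to the same storage, which can be verified on a made-up graph:

import dgl
import torch

g = dgl.graph(([0, 1], [1, 2]))
g.srcdata['h'] = torch.ones(3, 4)
# With a single node type, srcdata/dstdata are just aliases of ndata
print(torch.equal(g.ndata['h'], g.dstdata['h']))  # True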

Outputting the updated features

The activation function and normalization are applied to the new message to produce the new features:

        # Activation function
        if self.activation is not None:
            rst = self.activation(rst)
        # Normalization
        if self.norm is not None:
            rst = self.norm(rst)
        return rst

The "complete" code

After reformatting, the code provided in the docs looks as follows. A few problems remain:

  1. The "lstm" aggregator type is not supported: an nn.LSTM is registered in the constructor, but forward() never handles it.
  2. The destination nodes' own message "h_self" is never defined in the docs' version; below it is defined as "h_self = feat_dst".
  3. The formulas in the docs contain minor typesetting errors, corrected above.
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as fn
from dgl.utils import expand_as_pair, check_eq_shape

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation

        # Aggregator types: mean, max_pool, lstm, gcn
        if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'max_pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type in ['mean', 'max_pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def forward(self, graph, feat):
        with graph.local_scope():
            # Expand the input features according to the graph type
            feat_src, feat_dst = expand_as_pair(feat, graph)
            # Destination nodes' own features (missing in the docs' version)
            h_self = feat_dst

            if self._aggre_type == 'mean':
                graph.srcdata['h'] = feat_src
                graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
                # Divide by (in-degree + 1), counting the node itself
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
            elif self._aggre_type == 'max_pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

            # In GraphSAGE, the gcn aggregator does not need fc_self
            if self._aggre_type == 'gcn':
                rst = self.fc_neigh(h_neigh)
            else:
                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

            # Activation function
            if self.activation is not None:
                rst = self.activation(rst)
            # Normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst
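A quick smoke test of the class above on a random homogeneous graph (toy sizes, not from the docs):

import dgl
import torch

g = dgl.rand_graph(20, 60)                  # 20 nodes, 60 random edges
feat = torch.randn(20, 10)
conv = SAGEConv(10, 16, aggregator_type='mean', activation=torch.relu)
out = conv(g, feat)
print(out.shape)                            # torch.Size([20, 16])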

The code uses an "expand_as_pair()" function, which splits the input features into two parts according to the graph type. The details are below (for now only the homogeneous-graph case matters; it simply returns the same object twice):

def expand_as_pair(input_, g=None):
    if isinstance(input_, tuple):
        # Bipartite graph: the input is already a (src, dst) pair
        return input_
    elif g is not None and g.is_block:
        # Block (message flow graph): the destination nodes sit at the
        # front of the source-node ordering, so slice them off the front
        if isinstance(input_, Mapping):
            input_dst = {
                k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                for k, v in input_.items()}
        else:
            input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
        return input_, input_dst
    else:
        # Homogeneous graph: just duplicate the input
        return input_, input_
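The block case can be exercised with "dgl.to_block()", which is how mini-batch sampling produces graphs whose destination nodes are a prefix of the source nodes (a sketch on made-up nodes, assuming the default include_dst_in_src behavior):

import dgl
import torch
from dgl.utils import expand_as_pair

g = dgl.graph(([0, 1, 2], [3, 3, 4]))
block = dgl.to_block(g, dst_nodes=torch.tensor([3, 4]))
feat = torch.randn(block.num_src_nodes(), 8)
feat_src, feat_dst = expand_as_pair(feat, block)
print(feat_src.shape, feat_dst.shape)   # here: (5, 8) and (2, 8)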