Source Code Analysis of the Graph Neural Network Toolkit DGL (1): update_all


DGL official website: https://www.dgl.ai/

Background

DGL is one of the most popular toolkits for developing graph neural networks and is built on PyTorch. Although DGL also offers a TensorFlow API, its tutorials are written for PyTorch, so if you want to use DGL, I recommend getting comfortable with PyTorch first; it will make learning DGL much easier and more efficient.

This is the first post in the series "Source Code Analysis of the Graph Neural Network Toolkit DGL". The series collects what I have learned while studying DGL; think of it as study notes. Writing them up makes them easy for me to look up later, and I hope they will also help other readers getting started with graph neural networks.

dgl.DGLGraph.update_all implements message passing and is one of DGL's most important and fundamental functions. In this post we walk through its source code together.

Function API

Official documentation: https://docs.dgl.ai/generated/dgl.DGLGraph.update_all.html?highlight=update_all#dgl-dglgraph-update-all

(Screenshot of the update_all API documentation omitted.)

Call Path

  • update_all()
    def update_all(self,
                   message_func,
                   reduce_func,
                   apply_node_func=None,
                   etype=None):
        """Send messages along all the edges of the specified type
        and update all the nodes of the corresponding destination type.

        Parameters
        ----------
        message_func : dgl.function.BuiltinFunction or callable
            The message function to generate messages along the edges.
            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.
        reduce_func : dgl.function.BuiltinFunction or callable
            The reduce function to aggregate the messages.
            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.
        apply_node_func : callable, optional
            An optional apply function to further update the node features
            after the message reduction. It must be a :ref:`apiudf`.
        etype : str or (str, str, str), optional
            The type name of the edges. The allowed type name formats are:

            * ``(str, str, str)`` for source node type, edge type and destination node type.
            * or one ``str`` edge type name if the name can uniquely identify a
              triplet format in the graph.

            Can be omitted if the graph has only one type of edges.

        Notes
        -----
        * If some of the nodes in the graph have no in-edges, DGL does not invoke
          the message and reduce functions for these nodes and fills their aggregated
          messages with zero. Users can control the filled values via
          :meth:`set_n_initializer`. DGL still invokes :attr:`apply_node_func`
          if provided.
        * DGL recommends using DGL's built-in functions for the :attr:`message_func`
          and the :attr:`reduce_func` arguments, because DGL will invoke efficient
          kernels that avoid copying node features to edge features in this case.

        Examples
        --------
        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch

        **Homogeneous graph**

        >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
        >>> g.ndata['x'] = torch.ones(5, 2)
        >>> g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'h'))
        >>> g.ndata['h']
        tensor([[0., 0.],
                [1., 1.],
                [1., 1.],
                [1., 1.],
                [1., 1.]])

        **Heterogeneous graph**

        >>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2], [1, 2, 2])})

        Update all.

        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
        >>> g['follows'].update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows')
        >>> g.nodes['user'].data['h']
        tensor([[0.],
                [0.],
                [3.]])
        """        
        etid = self.get_etype_id(etype)
        etype = self.canonical_etypes[etid]
        _, dtid = self._graph.metagraph.find_edge(etid)
        g = self if etype is None else self[etype]
        ndata = core.message_passing(g, message_func, reduce_func, apply_node_func)
        self._set_n_repr(dtid, ALL, ndata)

update_all calls core.message_passing() to compute the new node features, which are returned in ndata; self._set_n_repr then writes ndata back into the graph.
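As a small illustration of the edge-type resolution at the top of update_all, the public getters it relies on can be called directly. This is a minimal sketch; the toy heterograph below is made up for demonstration:

import dgl

# A toy heterograph, just to exercise the etype lookups used by update_all.
g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1], [1, 2])})

etid = g.get_etype_id('follows')   # integer edge-type ID, here 0
print(g.canonical_etypes[etid])    # ('user', 'follows', 'user')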

  • message_passing()
def message_passing(g, mfunc, rfunc, afunc):
     """Invoke message passing computation on the whole graph.

    Parameters
    ----------
    g : DGLGraph
        The input graph.
    mfunc : callable or dgl.function.BuiltinFunction
        Message function.
    rfunc : callable or dgl.function.BuiltinFunction
        Reduce function.
    afunc : callable or dgl.function.BuiltinFunction
        Apply function.

    Returns
    -------
    dict[str, Tensor]
        Results from the message passing computation.
    """   
    if g.number_of_edges() == 0:
        # No message passing is triggered.
        ndata = {}
    elif (is_builtin(mfunc) and is_builtin(rfunc) and
          getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name), None) is not None):
        # invoke fused message passing
        ndata = invoke_gspmm(g, mfunc, rfunc)
    else:
        # invoke message passing in two separate steps
        # message phase
        if is_builtin(mfunc):
            msgdata = invoke_gsddmm(g, mfunc)
        else:
            orig_eid = g.edata.get(EID, None)
            msgdata = invoke_edge_udf(g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid)
        # reduce phase
        if is_builtin(rfunc):
            msg = rfunc.msg_field
            ndata = invoke_gspmm(g, fn.copy_e(msg, msg), rfunc, edata=msgdata)
        else:
            orig_nid = g.dstdata.get(NID, None)
            ndata = invoke_udf_reduce(g, rfunc, msgdata, orig_nid=orig_nid)
    # apply phase
    if afunc is not None:
        for k, v in g.dstdata.items():   # include original node features
            if k not in ndata:
                ndata[k] = v
        orig_nid = g.dstdata.get(NID, None)
        ndata = invoke_node_udf(g, ALL, g.dsttypes[0], afunc, ndata=ndata, orig_nid=orig_nid)
    return ndata

This function consists of three phases: the message phase, the reduce phase, and the apply phase. We focus on the user-defined function (UDF) entry point of each phase, because when developing a graph neural network model you usually implement message passing with custom UDFs. We therefore analyze the three functions invoke_edge_udf, invoke_udf_reduce, and invoke_node_udf next.
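Before diving into each UDF entry point, here is a minimal sketch of the dispatch logic above (the feature names 'x', 'm', and 'h' are arbitrary): a pair of built-in functions is fused into a single gSpMM kernel, while plain Python callables fall through to the UDF branches and produce the same result.

import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.ndata['x'] = torch.randn(3, 4)

# Both built-ins: ops.copy_u_sum exists, so the fused invoke_gspmm path runs.
g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'h'))

# Plain callables: routed through invoke_edge_udf and invoke_udf_reduce instead.
g.update_all(lambda edges: {'m': edges.src['x']},
             lambda nodes: {'h': nodes.mailbox['m'].sum(1)})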

  • invoke_edge_udf()
def invoke_edge_udf(graph, eid, etype, func, *, orig_eid=None):
    """Invoke user-defined edge function on the given edges.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    eid : Tensor
        The IDs of the edges to invoke UDF on.
    etype : (str, str, str)
        Edge type.
    func : callable
        The user-defined function.
    orig_eid : Tensor, optional
        Original edge IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    etid = graph.get_etype_id(etype)
    stid, dtid = graph._graph.metagraph.find_edge(etid)
    if is_all(eid):
        u, v, eid = graph.edges(form='all')
        edata = graph._edge_frames[etid]
    else:
        u, v = graph.find_edges(eid)
        edata = graph._edge_frames[etid].subframe(eid)
    srcdata = graph._node_frames[stid].subframe(u)
    dstdata = graph._node_frames[dtid].subframe(v)
    ebatch = EdgeBatch(graph, eid if orig_eid is None else orig_eid,
                       etype, srcdata, edata, dstdata)
    return func(ebatch)

When the message phase calls this function, the entire graph is passed in, and invoke_edge_udf packs the source node features, edge features, and destination node features into an EdgeBatch that is fed to the UDF. So the argument our own message UDF receives is an EdgeBatch.
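For illustration, a message UDF that touches the feature views an EdgeBatch exposes. This is only a sketch; the feature names 'h' and 'w' are hypothetical:

def message_func(edges):
    # edges.src, edges.dst, and edges.data are dict-like views of the
    # srcdata, dstdata, and edata frames packed into the EdgeBatch above.
    # Hypothetical features: source-node feature 'h', edge weight 'w'.
    return {'m': edges.src['h'] * edges.data['w']}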

  • invoke_udf_reduce()
def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):
    """Invoke user-defined reduce function on all the nodes in the graph.

    It analyzes the graph, groups nodes by their degrees and applies the UDF on each
    group -- a strategy called *degree-bucketing*.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    func : callable
        The user-defined function.
    msgdata : dict[str, Tensor]
        Message data.
    orig_nid : Tensor, optional
        Original node IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    degs = graph.in_degrees()
    nodes = graph.dstnodes()
    if orig_nid is None:
        orig_nid = nodes
    ntype = graph.dsttypes[0]
    ntid = graph.get_ntype_id_from_dst(ntype)
    dstdata = graph._node_frames[ntid]
    msgdata = Frame(msgdata)

    # degree bucketing
    unique_degs, bucketor = _bucketing(degs)
    bkt_rsts = []
    bkt_nodes = []
    for deg, node_bkt, orig_nid_bkt in zip(unique_degs, bucketor(nodes), bucketor(orig_nid)):
        if deg == 0:
            # skip reduce function for zero-degree nodes
            continue
        bkt_nodes.append(node_bkt)
        ndata_bkt = dstdata.subframe(node_bkt)

        # order the incoming edges per node by edge ID
        eid_bkt = F.zerocopy_to_numpy(graph.in_edges(node_bkt, form='eid'))
        assert len(eid_bkt) == deg * len(node_bkt)
        eid_bkt = np.sort(eid_bkt.reshape((len(node_bkt), deg)), 1)
        eid_bkt = F.zerocopy_from_numpy(eid_bkt.flatten())

        msgdata_bkt = msgdata.subframe(eid_bkt)
        # reshape all msg tensors to (num_nodes_bkt, degree, feat_size)
        maildata = {}
        for k, msg in msgdata_bkt.items():
            newshape = (len(node_bkt), deg) + F.shape(msg)[1:]
            maildata[k] = F.reshape(msg, newshape)
        # invoke udf
        nbatch = NodeBatch(graph, orig_nid_bkt, ntype, ndata_bkt, msgs=maildata)
        bkt_rsts.append(func(nbatch))

    # prepare a result frame
    retf = Frame(num_rows=len(nodes))
    retf._initializers = dstdata._initializers
    retf._default_initializer = dstdata._default_initializer

    # merge bucket results and write to the result frame
    if len(bkt_rsts) != 0:  # if all the nodes have zero degree, no need to merge results.
        merged_rst = {}
        for k in bkt_rsts[0].keys():
            merged_rst[k] = F.cat([rst[k] for rst in bkt_rsts], dim=0)
        merged_nodes = F.cat(bkt_nodes, dim=0)
        retf.update_row(merged_nodes, merged_rst)

    return retf

First, the nodes are bucketed by in-degree: nodes with the same in-degree form one NodeBatch, which is fed to the custom reduce function. Finally, the results of all NodeBatches are merged and returned as a frame view (dict[str, Tensor]).

Note:
(1) Nodes with an in-degree of 0 are skipped; their tensor values are determined by retf._initializers.
(2) The message tensor of shape [edge_num, dim] produced in the message phase is reshaped to [node_num, in_degree, dim] and placed into each NodeBatch as its mailbox, as sketched below.
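A reduce UDF therefore sees a per-bucket mailbox of shape (bucket_size, in_degree, feat_dim). A minimal sketch, with the message field 'm' and output name 'h' as placeholders:

def reduce_func(nodes):
    # nodes.mailbox['m'] has shape (bucket_size, in_degree, feat_dim);
    # degree bucketing guarantees a uniform in-degree within each batch.
    # Max pooling over the in-degree dimension, as an illustration.
    return {'h': nodes.mailbox['m'].max(dim=1)[0]}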

  • invoke_node_udf()
def invoke_node_udf(graph, nid, ntype, func, *, ndata=None, orig_nid=None):
    """Invoke user-defined node function on the given nodes.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    nid : Tensor
        The IDs of the nodes to invoke UDF on.
    ntype : str
        Node type.
    func : callable
        The user-defined function.
    ndata : dict[str, Tensor], optional
        If provided, apply the UDF on this ndata instead of the ndata of the graph.
    orig_nid : Tensor, optional
        Original node IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    ntid = graph.get_ntype_id(ntype)
    if ndata is None:
        if is_all(nid):
            ndata = graph._node_frames[ntid]
            nid = graph.nodes(ntype=ntype)
        else:
            ndata = graph._node_frames[ntid].subframe(nid)
    nbatch = NodeBatch(graph, nid if orig_nid is None else orig_nid, ntype, ndata)
    return func(nbatch)

This function packs its input into a NodeBatch and feeds it to the UDF, so in the apply phase our UDF again receives a NodeBatch, just as in the reduce phase. The difference is that apply feeds the UDF the data of the whole graph, whereas reduce feeds it the data of groups of nodes sharing the same in-degree.
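A matching apply UDF sketch (the feature name 'h' is a placeholder); note that no mailbox is attached in this phase, only nodes.data:

import torch

def apply_func(nodes):
    # Apply-phase UDFs see node features only; there is no mailbox.
    return {'h': torch.relu(nodes.data['h'])}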

Summary

Although a graph has a topology unlike regular structured data, the computation still converts that topology into tensors, for example the tensor of shape [node_num, in_degree, dim] in the reduce phase, and the nodes and edges involved are tracked by separate ID tensors. Data structures of the form dict[str, Tensor] are also used heavily.

Thanks for reading this far. To close, let me share a good way to study source code: write a small program and step through it in a debugger, checking at each step how the source implements the feature. For analyzing the update_all source code, my little program looks like this:

import torch
import dgl

# 7 nodes; nodes 0 and 1 each have three out-edges.
graph = dgl.graph(([0, 0, 0, 1, 1, 1], [1, 2, 3, 4, 5, 6]))
graph.ndata['attr1'] = 2 * torch.ones((graph.num_nodes(), 1))

def message_func(edges):
    # Send a constant 1 along every edge.
    return {'m': torch.ones(edges.batch_size(), 1)}

def reduce_func(nodes):
    # Summing the ones in the mailbox counts each node's in-degree.
    return {'in_degree': nodes.mailbox['m'].sum(1)}

def update_func(nodes):
    # Combine the aggregated result with the original node feature.
    return {'tt': nodes.data['in_degree'] + nodes.data['attr1']}

graph.update_all(message_func, reduce_func, update_func)
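Stepping through this script, message_func lands in invoke_edge_udf, reduce_func in invoke_udf_reduce, and update_func in invoke_node_udf, so all three phases can be watched in turn. Since node 0 has no in-edges, its 'in_degree' entry is zero-filled by the frame initializer while update_func still runs on it, so graph.ndata['tt'][0] should come out as 2.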

Finally, happy debugging!

