dgl源码阅读笔记(1)——update_all

dgl源码阅读笔记(1)——update_all

图神经网络开源库dgl阅读笔记



前言

update_all是dgl实现消息传递的重要函数,这里进行阅读了解
但是遗憾的是,目前没能彻底走完最底层的代码实现,希望有时间解决了遇到的问题。
对于这个大工程包,关于gsmm的调用及其的繁琐,很多层函数的套用。在学习的过程中,学到了修饰器,函数打包等在大型代码中是如何发挥作用的,也算是有点收获。
在debug过程是运行GCN模型后引用的该函数,有些数据集特征数,比如2708等,都是受数据集影响的,在文中为了方便理解所以写了下来。


一、update_all

查阅DGL手册了解到update_all的内容
在这里插入图片描述

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
g.ndata[‘x’] = torch.ones(5, 2)
g.update_all(fn.copy_u(‘x’, ‘m’), fn.sum(‘m’, ‘h’))
g.ndata[‘h’]
输出为
tensor([[0., 0.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]])

对于该函数的源码进行分析,了解底层逻辑。
首先对边的种类进行判断,etype=None,在get_etype_id中,如果etype是None就返回0
因为较少直接调用的函数有点多,尽可能多的在代码中用注释的方式解释
canonical_etypes调用的是DGLHeteroGraph类方法,在初始化的过程中:

def __init__(self,
                 gidx=[],
                 ntypes=['_N'],
                 etypes=['_E'],
                 node_frames=None,
                 edge_frames=None,
                 **deprecate_kwargs):

于是对于没有metapath的同构图,就会默认返回出[‘_N’,‘_E’,‘_N’]的这样一条边

    def update_all(self,
                   message_func,
                   reduce_func,
                   apply_node_func=None,
                   etype=None):
        # Graph with one relation type
        if self._graph.number_of_etypes() == 1 or etype is not None:
            etid = self.get_etype_id(etype)
            etype = self.canonical_etypes[etid] # 
            _, dtid = self._graph.metagraph.find_edge(etid) # 输入边的种类,输出该边
           	 													\连接的src和dst的节点种类
            g = self if etype is None else self[etype]  # 只选取etype内的边构成的子图
            ndata = core.message_passing(g, message_func, reduce_func, apply_node_func) # 进行消息
            																			  \传递
            if core.is_builtin(reduce_func) and reduce_func.name in ['min', 'max'] and ndata:
            	# is_builtin用于判断传入的reduce_func方法是否是在dgl自带函数中
                # Replace infinity with zero for isolated nodes
                key = list(ndata.keys())[0]
                ndata[key] = F.replace_inf_with_zero(ndata[key])
            self._set_n_repr(dtid, ALL, ndata)

1、is_builtin

is_builtin用于判断传入的reduce_func参数是否是在dgl自带函数中

def is_builtin(func):
    """Return true if the function is a DGL builtin function."""
    return isinstance(func, fn.BuiltinFunction)

2、message_passing

对整张图进行消息传递,输入图g和三个消息传递函数,消息函数,收缩函数,应用函数(默认为None)
但是也并没有直接执行消息传递的内容,而是先判断传入的参数,对应不同情况,执行invoke_gspmm实现消息传递
开始先判断两个函数是否是库支持的函数,否则就会进入到else中,去适应其他自定义的函数
然后调用invoke_gspmm(g, mfunc, rfunc)来实现

def message_passing(g, mfunc, rfunc, afunc):
    if (is_builtin(mfunc) and is_builtin(rfunc) and
            getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name), None) is not None):
        # invoke fused message passing
        ndata = invoke_gspmm(g, mfunc, rfunc)
    else:
        # invoke message passing in two separate steps
        # message phase
        if is_builtin(mfunc):
            msgdata = invoke_gsddmm(g, mfunc)
        else:
            orig_eid = g.edata.get(EID, None)
            msgdata = invoke_edge_udf(g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid)
        # reduce phase
        if is_builtin(rfunc):
            msg = rfunc.msg_field
            ndata = invoke_gspmm(g, fn.copy_e(msg, msg), rfunc, edata=msgdata)
        else:
            orig_nid = g.dstdata.get(NID, None)
            ndata = invoke_udf_reduce(g, rfunc, msgdata, orig_nid=orig_nid)
    # apply phase
    if afunc is not None:
        for k, v in g.dstdata.items():   # include original node features
            if k not in ndata:
                ndata[k] = v
        orig_nid = g.dstdata.get(NID, None)
        ndata = invoke_node_udf(g, ALL, g.dsttypes[0], afunc, ndata=ndata, orig_nid=orig_nid)
    return ndata

最后将invoke后的ndata传出,返回到update_all中的ndata

invoke_gspmm

上面说到了调用这个函数,这里看一下是如何实现的
需要注意到的是本函数的实现基于DGL团队2019年提出的g-SPMM(稀疏矩阵乘法)方法
这个是论文的链接
https://arxiv.org/abs/1909.01315
首先是函数的定义和参数的介绍

def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None):
    """Invoke g-SPMM computation on the graph.

    Parameters
    ----------
    graph :  DGLGraph
        The input graph.
    mfunc : dgl.function.BaseMessageFunction
        Built-in message function.
    rfunc : dgl.function.BaseReduceFunction
        Built-in reduce function.
    srcdata : dict[str, Tensor], optional
        Source node feature data. If not provided, it use ``graph.srcdata``.
    dstdata : dict[str, Tensor], optional
        Destination node feature data. If not provided, it use ``graph.dstdata``.
    edata : dict[str, Tensor], optional
        Edge feature data. If not provided, it use ``graph.edata``.

    Returns
    -------
    dict[str, Tensor]
        Results from the g-SPMM computation.
    """
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
>