DGL0.5中的g-SpMM和g-SDDMM

DGL0.5中的g-SpMM和g-SDDMM

导读:之前对DGL0.5论文中的g-SpMM和g-SDDMM做了个简单的笔记,这次去DGL源码中看一下其相关的使用。使用pytorch中的GATConv作为入口。

  论文中提到说:

users can invoke the g-SpMM and g-SDDMM kernels via the
g.update_all($\phi$,$\rho$) and g.apply_edges($\phi$) calls on a DGLGraph.

  正好GAT中这两者都有用到,接下来跟一下这部分的代码:

gatconv.py中的forward部分:

"""gatconv.py"""
    el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
    er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
    graph.srcdata.update({'ft': feat_src, 'el': el})
    graph.dstdata.update({'er': er})
    # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
    graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
    e = self.leaky_relu(graph.edata.pop('e'))
    # compute softmax
    graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
    # message passing
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    rst = graph.dstdata['ft']

  可以看到计算边的attention过程中用到了graph.apply_edges(),最终消息传递时则用到了graph.update_all()graph就是DGLheterograph,下面看看apply_edges()update_all()这两个函数:

"""heterograph.py"""
    def apply_edges(self, func, edges=ALL, etype=None, inplace=False):
    """Update the features of the specified edges by the provided function."""
        if inplace:
            raise DGLError('The `inplace` option is removed in v0.5.')
        etid = self.get_etype_id(etype)
        etype = self.canonical_etypes[etid]
        g = self if etype is None else self[etype]
        if is_all(edges):
            eid = ALL
        else:
            eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges')
        if core.is_builtin(func):
            if not is_all(eid):
                g = g.edge_subgraph(eid, preserve_nodes=True)
            edata = core.invoke_gsddmm(g, func)
        else:
            edata = core.invoke_edge_udf(g, eid, etype, func)
        self._set_e_repr(etid, eid, edata)
    
    
    def update_all(self,
                   message_func,
                   reduce_func,
                   apply_node_func=None,
                   etype=None):
        """Send messages along all the edges of the specified type
        and update all the nodes of the corresponding destination type."""
        
        etid = self.get_etype_id(etype)
        etype = self.canonical_etypes[etid]
        _, dtid = self._graph.metagraph.find_edge(etid)
        g = self if etype is None else self[etype]
        ndata = core.message_passing(g, message_func, reduce_func, apply_node_func)
        self._set_n_repr(dtid, ALL, ndata)

  可以发现,apply_edges()如果传入的func是DGL内置的话,就会调用core.invoke_gsddmm(g,func)得到所有边上的对应结果,而update_all()则会调用core.message_passing()得到所有结点上的对应结果。

  继续跟进,core.py完整路径为python/dgl/core.py.里面主要都是对graph上计算的实现。core.invoke_gsddmm(g, func)core.message_passing(g, message_func, reduce_func, apply_node_func)源码如下:

"""core.py"""
def invoke_gsddmm(graph, func):
    """Invoke g-SDDMM computation on the graph."""
    alldata = [graph.srcdata, graph.dstdata, graph.edata]
    if isinstance(func, fn.BinaryMessageFunction):
        x = alldata[func.lhs][func.lhs_field]
        y = alldata[func.rhs][func.rhs_field]
        op = getattr(ops, func.name)
        z = op(graph, x, y)
    else:
        x = alldata[func.target][func.in_field]
        op = getattr(ops, func.name)
        z = op(graph, x)
    return {func.out_field : z}

def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None):
    """Invoke g-SPMM computation on the graph."""
    # sanity check
    ...
    
    if isinstance(mfunc, fn.BinaryMessageFunction):
        x = alldata[mfunc.lhs][mfunc.lhs_field]
        y = alldata[mfunc.rhs][mfunc.rhs_field]
        op = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name))
        z = op(graph, x, y)
    else:
        x = alldata[mfunc.target][mfunc.in_field]
        op = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name))
        z = op(graph, x)
    return {rfunc.out_field : z}
    
    
def message_passing(g, mfunc, rfunc, afunc):
    """Invoke message passing computation on the whole graph."""
    if g.number_of_edges() == 0:
        # No message passing is triggered.
        ndata = {}
    elif (is_builtin(mfunc) and is_builtin(rfunc) and
          getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name), None) is not None):
        # invoke fused message passing
        ndata = invoke_gspmm(g, mfunc, rfunc)
    else:
        # 不是build in func时
        ...
    # apply phase
    if afunc is not None:
        ...
    return ndata

  可以看到message_passing中如果传入的func是DGL内置的,就会调用invoke_gspmm,这样与论文中说的就对应上了。
  invoke_gssdminvoke_gspmm都需要通过getattr得到op。通过debug,将进入到/python/dgl/ops/sddmm.py或相同路径下的spmm.py中调用gsddmm_internalgspmm_internal函数,该函数的执行应该取决于当前后端,如果当前是pytorch后端,就会跳转到/python/dgl/backend/pytorch/sparse.py中,开始gsddmm或者gspmm的forward运算。同样在sparse.py中还实现了gsddm和gspmm的backward函数,应该是用于梯度计算的。sparse.py的的成员如下图所示:

在这里插入图片描述

  在forward计算时,会调用/python/dgl/sparse.py中的_gsddmm()sparse.py中的_gsddmm()函数,其会调用DGL中C++实现的_CAPI_DGLKernelSDDMM()_CAPI_DGLKernelSpMM().

  _CAPI_DGLKernelSDDMM()_CAPI_DGLKernelSpMM()的定义处则是在源码/src/array/kernel.cc中,如下:

//kernel.cc
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    ...
    CheckShape(
        {graph->NumVertices(src_vtype), graph->NumEdges(0), graph->NumVertices(dst_vtype)},
        {lhs_target, rhs_target, 1},
        {lhs, rhs, out},
        {"U_data", "E_data", "V_data"});
    SDDMM(op, graph.sptr(), lhs, rhs, out, lhs_target, rhs_target);
  });

  最终的SpMMSDDMM的计算则是调用/src/array/cpu/src/array/cuda目录下的各个类完成,而且SpMMSDDMM计算都有分别针对csrcoo格式的实现。至此就大概走完流程了。

  DGL自己实现了稀疏矩阵,首先看看调用_CAPI_DGLKernelSpMM()时传入的参数(/python/dgl/sparse.py中):

"""sparse.py"""
from . import backend as F
def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
    ...省略无关代码
    if gidx.number_of_edges(0) > 0:
        _CAPI_DGLKernelSDDMM(gidx, op,
                             to_dgl_nd(lhs if use_lhs else None),
                             to_dgl_nd(rhs if use_rhs else None),
                             to_dgl_nd_for_write(out),
                             lhs_target, rhs_target)

def to_dgl_nd(x):
    """Convert framework-specific tensor/None to dgl ndarray."""
    return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray(x)

  可以发现其中调用了to_dgl_nd,进一步调用了F.zerocopy_to_dgl_ndarray(x),可以发现/python/dgl/ndarray.py中同样用的是C++接口,如_CAPI_DGLSparseMatrixGetFormat(self),它们定义在/src/array/array.cc中:

//array.cc
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFormat")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    SparseMatrixRef spmat = args[0];
    *rv = spmat->format;
  });

  array.cc中还定义了很多稀疏矩阵的操作,同目录下还有/cuda/cpu两个目录,其中定义了针对GPU和CPU的计算操作。

cpugpu
在这里插入图片描述在这里插入图片描述

  csrcoo等稀疏矩阵的数据结构则定义在/include/dgl/目录下。

延伸阅读:DGL现在还在开发新的tvm kernel,参考如下:
https://arxiv.org/abs/2008.11359
https://github.com/dmlc/dgl/pull/2136/files?file-filters%5B%5D=

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值