DGL official website: https://www.dgl.ai/
Background
DGL is currently a very popular toolkit for developing graph neural networks, implemented on top of PyTorch. Although DGL also offers a TensorFlow API, the tutorials are PyTorch-based, so if you want to use DGL I recommend getting comfortable with PyTorch first; it will make learning DGL much easier and more efficient.
This is the first post in the series "Analyzing the Source Code of the GNN Toolkit DGL". The series collects what I learned while studying DGL; think of it as my study notes. Writing them up makes it easy for me to look things up later, and I hope it also helps others getting started with graph neural networks.
dgl.DGLGraph.update_all implements message passing. It is one of the most fundamental and important functions in DGL, so today we walk through its source code together.
Function API
Official documentation: https://docs.dgl.ai/generated/dgl.DGLGraph.update_all.html?highlight=update_all#dgl-dglgraph-update-all
Call path
- update_all()
def update_all(self,
               message_func,
               reduce_func,
               apply_node_func=None,
               etype=None):
    """Send messages along all the edges of the specified type
    and update all the nodes of the corresponding destination type.

    Parameters
    ----------
    message_func : dgl.function.BuiltinFunction or callable
        The message function to generate messages along the edges.
        It must be either a :ref:`api-built-in` or a :ref:`apiudf`.
    reduce_func : dgl.function.BuiltinFunction or callable
        The reduce function to aggregate the messages.
        It must be either a :ref:`api-built-in` or a :ref:`apiudf`.
    apply_node_func : callable, optional
        An optional apply function to further update the node features
        after the message reduction. It must be a :ref:`apiudf`.
    etype : str or (str, str, str), optional
        The type name of the edges. The allowed type name formats are:

        * ``(str, str, str)`` for source node type, edge type and destination node type.
        * or one ``str`` edge type name if the name can uniquely identify a
          triplet format in the graph.

        Can be omitted if the graph has only one type of edges.

    Notes
    -----
    * If some of the nodes in the graph have no in-edges, DGL does not invoke
      message and reduce functions for these nodes and fills their aggregated messages
      with zero. Users can control the filled values via :meth:`set_n_initializer`.
      DGL still invokes :attr:`apply_node_func` if provided.
    * DGL recommends using DGL's built-in functions for the :attr:`message_func`
      and the :attr:`reduce_func` arguments,
      because DGL will invoke efficient kernels that avoid copying node features to
      edge features in this case.

    Examples
    --------
    >>> import dgl
    >>> import dgl.function as fn
    >>> import torch

    **Homogeneous graph**

    >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
    >>> g.ndata['x'] = torch.ones(5, 2)
    >>> g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'h'))
    >>> g.ndata['h']
    tensor([[0., 0.],
            [1., 1.],
            [1., 1.],
            [1., 1.],
            [1., 1.]])

    **Heterogeneous graph**

    >>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2], [1, 2, 2])})

    Update all.

    >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
    >>> g['follows'].update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'), etype='follows')
    >>> g.nodes['user'].data['h']
    tensor([[0.],
            [0.],
            [3.]])
    """
    etid = self.get_etype_id(etype)
    etype = self.canonical_etypes[etid]
    _, dtid = self._graph.metagraph.find_edge(etid)
    g = self if etype is None else self[etype]
    ndata = core.message_passing(g, message_func, reduce_func, apply_node_func)
    self._set_n_repr(dtid, ALL, ndata)
update_all calls core.message_passing() to compute the new node features, which come back in ndata. self._set_n_repr then writes ndata back into the graph.
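The effect of the etype argument is easiest to see on a heterograph with more than one edge type: only the destination-type nodes of the chosen relation are updated. Below is a minimal sketch; the 'user'/'game'/'plays'/'follows' type names are made up for illustration:

import dgl
import dgl.function as fn
import torch

g = dgl.heterograph({
    ('user', 'plays', 'game'): ([0, 1], [0, 0]),
    ('user', 'follows', 'user'): ([0], [1]),
})
g.nodes['user'].data['h'] = torch.ones(2, 1)
# Only the 'plays' relation participates, so only 'game' nodes
# (the destination type of 'plays') receive the aggregated feature.
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'), etype='plays')
print(g.nodes['game'].data['h_sum'])  # tensor([[2.]])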
- message_passing()
def message_passing(g, mfunc, rfunc, afunc):
    """Invoke message passing computation on the whole graph.

    Parameters
    ----------
    g : DGLGraph
        The input graph.
    mfunc : callable or dgl.function.BuiltinFunction
        Message function.
    rfunc : callable or dgl.function.BuiltinFunction
        Reduce function.
    afunc : callable or dgl.function.BuiltinFunction
        Apply function.

    Returns
    -------
    dict[str, Tensor]
        Results from the message passing computation.
    """
    if g.number_of_edges() == 0:
        # No message passing is triggered.
        ndata = {}
    elif (is_builtin(mfunc) and is_builtin(rfunc) and
          getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name), None) is not None):
        # invoke fused message passing
        ndata = invoke_gspmm(g, mfunc, rfunc)
    else:
        # invoke message passing in two separate steps
        # message phase
        if is_builtin(mfunc):
            msgdata = invoke_gsddmm(g, mfunc)
        else:
            orig_eid = g.edata.get(EID, None)
            msgdata = invoke_edge_udf(g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid)
        # reduce phase
        if is_builtin(rfunc):
            msg = rfunc.msg_field
            ndata = invoke_gspmm(g, fn.copy_e(msg, msg), rfunc, edata=msgdata)
        else:
            orig_nid = g.dstdata.get(NID, None)
            ndata = invoke_udf_reduce(g, rfunc, msgdata, orig_nid=orig_nid)
    # apply phase
    if afunc is not None:
        for k, v in g.dstdata.items():  # include original node features
            if k not in ndata:
                ndata[k] = v
        orig_nid = g.dstdata.get(NID, None)
        ndata = invoke_node_udf(g, ALL, g.dsttypes[0], afunc, ndata=ndata, orig_nid=orig_nid)
    return ndata
This code consists of three phases: the message phase, the reduce phase, and the apply phase. We will focus on the user-defined function (UDF) interface of each phase, because when developing graph neural network models we usually implement message passing with custom UDFs. So next we analyze the three functions invoke_edge_udf, invoke_udf_reduce and invoke_node_udf.
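Before diving into those three, note the fused branch at the top: when both mfunc and rfunc are built-ins and a matching fused operator exists in the ops module, DGL runs the whole message+reduce step as a single gSpMM kernel instead of materializing messages on edges. Here is a quick sketch of that equivalence, assuming a DGL version (>= 0.5) where dgl.ops exposes fused operators such as copy_u_sum:

import dgl
import dgl.function as fn
import dgl.ops
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
x = torch.randn(3, 4)
g.ndata['x'] = x

# built-in pair: update_all takes the fused invoke_gspmm path
g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'h'))

# the same computation, expressed directly with the fused operator
h = dgl.ops.copy_u_sum(g, x)
assert torch.allclose(g.ndata['h'], h)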
- invoke_edge_udf()
def invoke_edge_udf(graph, eid, etype, func, *, orig_eid=None):
    """Invoke user-defined edge function on the given edges.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    eid : Tensor
        The IDs of the edges to invoke UDF on.
    etype : (str, str, str)
        Edge type.
    func : callable
        The user-defined function.
    orig_eid : Tensor, optional
        Original edge IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    etid = graph.get_etype_id(etype)
    stid, dtid = graph._graph.metagraph.find_edge(etid)
    if is_all(eid):
        u, v, eid = graph.edges(form='all')
        edata = graph._edge_frames[etid]
    else:
        u, v = graph.find_edges(eid)
        edata = graph._edge_frames[etid].subframe(eid)
    srcdata = graph._node_frames[stid].subframe(u)
    dstdata = graph._node_frames[dtid].subframe(v)
    ebatch = EdgeBatch(graph, eid if orig_eid is None else orig_eid,
                       etype, srcdata, edata, dstdata)
    return func(ebatch)
When the message phase calls this function, the whole graph is passed in, and invoke_edge_udf packs the source, destination and edge features into an EdgeBatch, which is fed to the UDF. So the argument of our own message UDF is an EdgeBatch.
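In other words, a message UDF just reads from the EdgeBatch's .src, .dst and .data feature views (each tensor's first dimension is the number of edges in the batch) and returns a message dict. A minimal sketch:

import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1], [1, 2]))
g.ndata['h'] = torch.tensor([[1.], [2.], [3.]])
g.edata['w'] = torch.tensor([[10.], [20.]])

def message_func(edges):
    # edges is an EdgeBatch; edges.src / edges.dst / edges.data are
    # dict-like views with one row per edge in the batch
    return {'m': edges.src['h'] * edges.data['w']}

g.update_all(message_func, fn.sum('m', 'h_new'))
print(g.ndata['h_new'])  # tensor([[ 0.], [10.], [40.]])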
- invoke_udf_reduce()
def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):
    """Invoke user-defined reduce function on all the nodes in the graph.

    It analyzes the graph, groups nodes by their degrees and applies the UDF on each
    group -- a strategy called *degree-bucketing*.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    func : callable
        The user-defined function.
    msgdata : dict[str, Tensor]
        Message data.
    orig_nid : Tensor, optional
        Original node IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    degs = graph.in_degrees()
    nodes = graph.dstnodes()
    if orig_nid is None:
        orig_nid = nodes
    ntype = graph.dsttypes[0]
    ntid = graph.get_ntype_id_from_dst(ntype)
    dstdata = graph._node_frames[ntid]
    msgdata = Frame(msgdata)

    # degree bucketing
    unique_degs, bucketor = _bucketing(degs)
    bkt_rsts = []
    bkt_nodes = []
    for deg, node_bkt, orig_nid_bkt in zip(unique_degs, bucketor(nodes), bucketor(orig_nid)):
        if deg == 0:
            # skip reduce function for zero-degree nodes
            continue
        bkt_nodes.append(node_bkt)
        ndata_bkt = dstdata.subframe(node_bkt)

        # order the incoming edges per node by edge ID
        eid_bkt = F.zerocopy_to_numpy(graph.in_edges(node_bkt, form='eid'))
        assert len(eid_bkt) == deg * len(node_bkt)
        eid_bkt = np.sort(eid_bkt.reshape((len(node_bkt), deg)), 1)
        eid_bkt = F.zerocopy_from_numpy(eid_bkt.flatten())

        msgdata_bkt = msgdata.subframe(eid_bkt)
        # reshape all msg tensors to (num_nodes_bkt, degree, feat_size)
        maildata = {}
        for k, msg in msgdata_bkt.items():
            newshape = (len(node_bkt), deg) + F.shape(msg)[1:]
            maildata[k] = F.reshape(msg, newshape)
        # invoke udf
        nbatch = NodeBatch(graph, orig_nid_bkt, ntype, ndata_bkt, msgs=maildata)
        bkt_rsts.append(func(nbatch))

    # prepare a result frame
    retf = Frame(num_rows=len(nodes))
    retf._initializers = dstdata._initializers
    retf._default_initializer = dstdata._default_initializer

    # merge bucket results and write to the result frame
    if len(bkt_rsts) != 0:  # if all the nodes have zero degree, no need to merge results.
        merged_rst = {}
        for k in bkt_rsts[0].keys():
            merged_rst[k] = F.cat([rst[k] for rst in bkt_rsts], dim=0)
        merged_nodes = F.cat(bkt_nodes, dim=0)
        retf.update_row(merged_nodes, merged_rst)

    return retf
First, nodes are bucketed by their in-degree; nodes with the same in-degree form one NodeBatch, which is fed into the custom reduce function. Finally the results of all NodeBatches are merged and returned as a Frame, which behaves like a dict[str, Tensor].
Note:
(1) Nodes with in-degree 0 are skipped entirely; their tensor values are determined by retf._initializers.
(2) The tensor of shape (edge_num, dim) produced in the message phase is reshaped to (node_num, in_degree, dim) and placed in each NodeBatch's mailbox, as the sketch below demonstrates.
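Note (2) is easy to verify from inside a reduce UDF: for a bucket of nodes that all have in-degree deg, nodes.mailbox['m'] has shape (num_nodes_in_bucket, deg, feat_size), and the UDF is called once per bucket. A minimal sketch:

import dgl
import dgl.function as fn
import torch

# node 1 has in-degree 1 and node 2 has in-degree 2,
# so degree bucketing produces two buckets (degree 0 is skipped)
g = dgl.graph(([0, 0, 1], [1, 2, 2]))
g.ndata['x'] = torch.ones(3, 4)

def reduce_func(nodes):
    m = nodes.mailbox['m']  # shape: (num_nodes_in_bucket, deg, 4)
    print(nodes.batch_size(), m.shape)
    return {'h': m.mean(1)}  # aggregate over the degree axis

g.update_all(fn.copy_u('x', 'm'), reduce_func)
# prints: 1 torch.Size([1, 1, 4]) and then 1 torch.Size([1, 2, 4])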
- invoke_node_udf()
def invoke_node_udf(graph, nid, ntype, func, *, ndata=None, orig_nid=None):
    """Invoke user-defined node function on the given nodes.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    nid : Tensor
        The IDs of the nodes to invoke UDF on.
    ntype : str
        Node type.
    func : callable
        The user-defined function.
    ndata : dict[str, Tensor], optional
        If provided, apply the UDF on this ndata instead of the ndata of the graph.
    orig_nid : Tensor, optional
        Original node IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    ntid = graph.get_ntype_id(ntype)
    if ndata is None:
        if is_all(nid):
            ndata = graph._node_frames[ntid]
            nid = graph.nodes(ntype=ntype)
        else:
            ndata = graph._node_frames[ntid].subframe(nid)
    nbatch = NodeBatch(graph, nid if orig_nid is None else orig_nid, ntype, ndata)
    return func(nbatch)
This function also packs its input into a NodeBatch before calling the UDF, so in the apply phase our UDF again receives a NodeBatch, just as in the reduce phase. The difference is that apply feeds the UDF the data of the whole graph at once, whereas reduce feeds it one batch of nodes with the same in-degree at a time.
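So an apply UDF also works on a NodeBatch, and it typically mixes the reduced result with the original node features through nodes.data (as we saw in message_passing, the original features are merged into ndata before the call). A minimal sketch:

import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1], [1, 2]))
g.ndata['x'] = torch.ones(3, 2)

def apply_func(nodes):
    # one NodeBatch covering all destination nodes;
    # nodes.data holds both the reduced 'h' and the original 'x'
    return {'h': torch.relu(nodes.data['h'] + nodes.data['x'])}

g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'h'), apply_func)
print(g.ndata['h'])  # tensor([[1., 1.], [2., 2.], [2., 2.]])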
Summary
Although a graph has a topology unlike regular structured data, computation still flattens that topology into tensors, for example the (node_num, in_degree, dim) tensor of the reduce phase, and the nodes and edges involved are tracked by separate ID tensors. Data structures like dict[str, Tensor] are also used throughout.
Thanks for reading this far. To wrap up, let me share a good way to study source code: write a small program and step through it in a debugger to see how each step is implemented. For analyzing the source code of update_all, my little program looks like this:
import torch
import dgl

# 7 nodes; node 0 has no in-edges, nodes 1-6 each have in-degree 1
graph = dgl.graph(([0, 0, 0, 1, 1, 1], [1, 2, 3, 4, 5, 6]))
graph.ndata['attr1'] = 2 * torch.ones((graph.num_nodes(), 1))

def message_func(edges):
    # message phase: every edge carries a constant 1
    return {'m': torch.ones(edges.batch_size(), 1)}

def reduce_func(nodes):
    # reduce phase: summing the ones yields each node's in-degree
    return {'in_degree': nodes.mailbox['m'].sum(1)}

def update_func(nodes):
    # apply phase: combine the reduced result with an original feature
    return {'tt': nodes.data['in_degree'] + nodes.data['attr1']}

graph.update_all(message_func, reduce_func, update_func)
print(graph.ndata['tt'])  # node 0: 0 + 2 = 2; nodes 1-6: 1 + 2 = 3

Finally, happy debugging!