dgl源码阅读笔记(1)——update_all
图神经网络开源库dgl阅读笔记
文章目录
前言
update_all是dgl实现消息传递的重要函数,这里进行阅读了解
但是遗憾的是,目前没能彻底走完最底层的代码实现,希望有时间解决了遇到的问题。
对于这个大工程包,关于gsmm的调用及其的繁琐,很多层函数的套用。在学习的过程中,学到了修饰器,函数打包等在大型代码中是如何发挥作用的,也算是有点收获。
在debug过程是运行GCN模型后引用的该函数,有些数据集特征数,比如2708等,都是受数据集影响的,在文中为了方便理解所以写了下来。
一、update_all
查阅DGL手册了解到update_all的内容

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
g.ndata[‘x’] = torch.ones(5, 2)
g.update_all(fn.copy_u(‘x’, ‘m’), fn.sum(‘m’, ‘h’))
g.ndata[‘h’]
输出为
tensor([[0., 0.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]])
对于该函数的源码进行分析,了解底层逻辑。
首先对边的种类进行判断,etype=None,在get_etype_id中,如果etype是None就返回0
因为较少直接调用的函数有点多,尽可能多的在代码中用注释的方式解释
canonical_etypes调用的是DGLHeteroGraph类方法,在初始化的过程中:
def __init__(self,
gidx=[],
ntypes=['_N'],
etypes=['_E'],
node_frames=None,
edge_frames=None,
**deprecate_kwargs):
于是对于没有metapath的同构图,就会默认返回出[‘_N’,‘_E’,‘_N’]的这样一条边
def update_all(self,
message_func,
reduce_func,
apply_node_func=None,
etype=None):
# Graph with one relation type
if self._graph.number_of_etypes() == 1 or etype is not None:
etid = self.get_etype_id(etype)
etype = self.canonical_etypes[etid] #
_, dtid = self._graph.metagraph.find_edge(etid) # 输入边的种类,输出该边
\连接的src和dst的节点种类
g = self if etype is None else self[etype] # 只选取etype内的边构成的子图
ndata = core.message_passing(g, message_func, reduce_func, apply_node_func) # 进行消息
\传递
if core.is_builtin(reduce_func) and reduce_func.name in ['min', 'max'] and ndata:
# is_builtin用于判断传入的reduce_func方法是否是在dgl自带函数中
# Replace infinity with zero for isolated nodes
key = list(ndata.keys())[0]
ndata[key] = F.replace_inf_with_zero(ndata[key])
self._set_n_repr(dtid, ALL, ndata)
1、is_builtin
is_builtin用于判断传入的reduce_func参数是否是在dgl自带函数中
def is_builtin(func):
"""Return true if the function is a DGL builtin function."""
return isinstance(func, fn.BuiltinFunction)
2、message_passing
对整张图进行消息传递,输入图g和三个消息传递函数,消息函数,收缩函数,应用函数(默认为None)
但是也并没有直接执行消息传递的内容,而是先判断传入的参数,对应不同情况,执行invoke_gspmm实现消息传递
开始先判断两个函数是否是库支持的函数,否则就会进入到else中,去适应其他自定义的函数
然后调用invoke_gspmm(g, mfunc, rfunc)来实现
def message_passing(g, mfunc, rfunc, afunc):
if (is_builtin(mfunc) and is_builtin(rfunc) and
getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name), None) is not None):
# invoke fused message passing
ndata = invoke_gspmm(g, mfunc, rfunc)
else:
# invoke message passing in two separate steps
# message phase
if is_builtin(mfunc):
msgdata = invoke_gsddmm(g, mfunc)
else:
orig_eid = g.edata.get(EID, None)
msgdata = invoke_edge_udf(g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid)
# reduce phase
if is_builtin(rfunc):
msg = rfunc.msg_field
ndata = invoke_gspmm(g, fn.copy_e(msg, msg), rfunc, edata=msgdata)
else:
orig_nid = g.dstdata.get(NID, None)
ndata = invoke_udf_reduce(g, rfunc, msgdata, orig_nid=orig_nid)
# apply phase
if afunc is not None:
for k, v in g.dstdata.items(): # include original node features
if k not in ndata:
ndata[k] = v
orig_nid = g.dstdata.get(NID, None)
ndata = invoke_node_udf(g, ALL, g.dsttypes[0], afunc, ndata=ndata, orig_nid=orig_nid)
return ndata
最后将invoke后的ndata传出,返回到update_all中的ndata
invoke_gspmm
上面说到了调用这个函数,这里看一下是如何实现的
需要注意到的是本函数的实现基于DGL团队2019年提出的g-SPMM(稀疏矩阵乘法)方法
这个是论文的链接
https://arxiv.org/abs/1909.01315
首先是函数的定义和参数的介绍
def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None):
"""Invoke g-SPMM computation on the graph.
Parameters
----------
graph : DGLGraph
The input graph.
mfunc : dgl.function.BaseMessageFunction
Built-in message function.
rfunc : dgl.function.BaseReduceFunction
Built-in reduce function.
srcdata : dict[str, Tensor], optional
Source node feature data. If not provided, it use ``graph.srcdata``.
dstdata : dict[str, Tensor], optional
Destination node feature data. If not provided, it use ``graph.dstdata``.
edata : dict[str, Tensor], optional
Edge feature data. If not provided, it use ``graph.edata``.
Returns
-------
dict[str, Tensor]
Results from the g-SPMM computation.
"""

最低0.47元/天 解锁文章
1985

被折叠的 条评论
为什么被折叠?



