Getting Started with DGL: Writing Your First Graph Convolution Layer (Message Passing, Aggregation, and Graph Convolution)
1. Understanding DGL's message-passing flow
DGL's message passing consists of four steps:
- Message function: defined over every edge. It specifies how to combine an edge's features with those of its endpoint nodes; the result is "mailed" to the destination node's mailbox.

```python
def message_func(edges):
    # edges has three member attributes: src, dst, and data, which access
    # the source-node, destination-node, and edge features respectively.
    # Here each destination node's mailbox receives a message named 'm'.
    return {'m': edges.src['h'] + edges.dst['h']}
```
- Reduce (aggregation) function: a node may have several incoming edges, each of which delivers a message to that node's mailbox. The reduce function specifies how to aggregate the mailbox contents.

```python
def reduce_func(nodes):
    # nodes.mailbox accesses the messages each node has received.
    # Common aggregations include sum, max, and min; here we sum the
    # messages 'm' and store the result in each node's 'h' feature.
    return {'h': torch.sum(nodes.mailbox['m'], dim=1)}
```
- Node update (apply) function: an optional final step that further transforms each node's aggregated features.

```python
def apply_node_func(nodes):
    # here we simply double each node's aggregated feature
    return {'h': nodes.data['h'] + nodes.data['h']}
```

For convenience, we collect the three functions in a dictionary:

```python
func_dict = {
    'message_func': message_func,
    'reduce_func': reduce_func,
    'apply_node_func': apply_node_func,
}
```
- Updating the whole graph: `update_all` runs the steps above over all nodes and edges in one call, which completes the full DGL message-passing flow.

```python
graph.update_all(message_func, reduce_func)
# or, including the optional node update function:
graph.update_all(message_func, reduce_func, apply_node_func)
```
2. Writing your first graph convolution layer
Based on the flow above, we can now design our own graph convolution layer.
```python
import torch as th
from torch import nn
from torch.nn import init
from dgl.utils import expand_as_pair


class UserDefinedGraphConv(nn.Module):
    def __init__(self, conv_config):
        super().__init__()
        self._in_feats = conv_config['in_feats']
        self._out_feats = conv_config['out_feats']
        # pass weight=None in the config to supply an external weight
        # in forward() instead of learning one in the module
        if conv_config.get('weight', True) is None:
            self.register_parameter('weight', None)
        else:
            self.weight = nn.Parameter(th.Tensor(self._in_feats, self._out_feats))
        if conv_config.get('bias', True) is None:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(th.Tensor(self._out_feats))
        self._activation = conv_config.get('activation')
        self.reset_parameter()
        # further parameters can be defined here

    def reset_parameter(self):
        if self.weight is not None:
            init.xavier_uniform_(self.weight)
        if self.bias is not None:
            init.zeros_(self.bias)

    def forward(self, graph, feat, func_dict, weight=None):
        with graph.local_scope():
            # graph sanity checks (e.g. zero-in-degree handling) can go here
            # split the input feature into source- and destination-node features
            feat_src, feat_dst = expand_as_pair(feat, graph)
            if weight is not None:
                if self.weight is not None:
                    raise ValueError(
                        'An external weight was provided while the module has '
                        'its own weight parameter. Create the module with '
                        'weight=None in the config.')
            else:
                weight = self.weight
            # normalization (e.g. by node degree) could be applied here,
            # before the convolution
            if self._in_feats > self._out_feats:
                # multiply by W first to shrink the feature size before aggregation
                if weight is not None:
                    feat_src = th.matmul(feat_src, weight)
                graph.srcdata['h'] = feat_src
                graph.update_all(func_dict['message_func'], func_dict['reduce_func'])
                # or, with the optional node update function:
                # graph.update_all(func_dict['message_func'],
                #                  func_dict['reduce_func'],
                #                  func_dict['apply_node_func'])
                rst = graph.dstdata['h']
            else:
                graph.srcdata['h'] = feat_src
                graph.update_all(func_dict['message_func'], func_dict['reduce_func'])
                # or, with the optional node update function:
                # graph.update_all(func_dict['message_func'],
                #                  func_dict['reduce_func'],
                #                  func_dict['apply_node_func'])
                rst = graph.dstdata['h']
                if weight is not None:
                    rst = th.matmul(rst, weight)
            # normalization could also be applied here, after the convolution
            if self.bias is not None:
                rst = rst + self.bias
            if self._activation is not None:
                rst = self._activation(rst)
            return rst

    def extra_repr(self):
        return 'in_feats={}, out_feats={}'.format(self._in_feats, self._out_feats)
```
3. References
- https://discuss.dgl.ai/t/expand-as-pair-explanation/1311