第3章:构建图神经网络(GNN)模块
DGL NN模块是用户构建GNN模型的基本模块。根据DGL所使用的后端深度神经网络框架, DGL NN模块的父类取决于后端所使用的深度神经网络框架。对于PyTorch后端, 它应该继承 PyTorch的NN模块;对于MXNet后端,它应该继承 MXNet Gluon的NN块; 对于TensorFlow后端,它应该继承 Tensorflow的Keras层。 在DGL NN模块中,构造函数中的参数注册和前向传播函数中使用的张量操作与后端框架一样。这种方式使得DGL的代码可以无缝嵌入到后端框架的代码中。 DGL和这些深度神经网络框架的主要差异是其独有的消息传递操作。
DGL已经集成了很多常用的 apinn-pytorch-conv、 apinn-pytorch-dense-conv、 apinn-pytorch-pooling 和 apinn-pytorch-util。欢迎给DGL贡献更多的模块!
本章将使用PyTorch作为后端,用 SAGEConv 作为例子来介绍如何构建用户自己的DGL NN模块。
DGL NN模块的构造函数
构造函数完成以下几个任务:
- 设置选项。
- 注册可学习的参数或者子模块。
- 初始化参数。
import torch.nn as nn
from dgl.utils import expand_as_pair
class SAGEConv(nn.Module):
def __init__(self,
in_feats,
out_feats,
aggregator_type,
bias=True,
norm=None,
activation=None):
super(SAGEConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
self._aggre_type = aggregator_type
self.norm = norm
self.activation = activation
在构造函数中,用户首先需要设置数据的维度。对于一般的PyTorch模块,维度通常包括输入的维度、输出的维度和隐层的维度。 对于图神经网络,输入维度可被分为源节点特征维度和目标节点特征维度。
除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type
)。对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。 常用的聚合类型包括 mean
、 sum
、 max
和 min
。一些模块可能会使用更加复杂的聚合函数,比如 lstm
。
上面代码里的 norm
是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化:
h
𝑣
=
h
𝑣
/
‖
h
𝑣
‖
2
ℎ_𝑣=ℎ_𝑣/‖ℎ_𝑣‖2
hv=hv/‖hv‖2。
# 聚合类型:mean、pool、lstm、gcn
if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:
raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
if aggregator_type == 'pool':
self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type in ['mean', 'pool', 'lstm']:
self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
self.reset_parameters()
注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 nn.Linear
、 nn.LSTM
等。 构造函数的最后调用了 reset_parameters()
进行权重初始化。
def reset_parameters(self):
"""重新初始化可学习的参数"""
gain = nn.init.calculate_gain('relu')
if self._aggre_type == 'pool':
nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
if self._aggre_type == 'lstm':
self.lstm.reset_parameters()
if self._aggre_type != 'gcn':
nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
编写DGL NN模块的forward函数
在NN模块中, forward()
函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比, DGL NN模块额外增加了1个参数 dgl.DGLGraph。forward()
函数的内容一般可以分为3项操作:
- 检测输入图对象是否符合规范。
- 消息传递和聚合。
- 聚合后,更新特征作为输出。
下文展示了SAGEConv示例中的 forward()
函数。
输入图对象的规范检测
def forward(self, graph, feat):
with graph.local_scope():
# 指定图类型,然后根据图类型扩展输入特征
feat_src, feat_dst = expand_as_pair(feat, graph)
forward()
函数需要处理输入的许多极端情况,这些情况可能导致计算和消息传递中的值无效。 比如在 GraphConv 等conv模块中,DGL会检查输入图中是否有入度为0的节点。 当1个节点入度为0时, mailbox
将为空,并且聚合函数的输出值全为0, 这可能会导致模型性能不佳。但是,在 SAGEConv 模块中,被聚合的特征将会与节点的初始特征拼接起来, forward()
函数的输出不会全为0。在这种情况下,无需进行此类检验。
DGL NN模块可在不同类型的图输入中重复使用,包括:同构图、异构图(1.5 异构图)和子图块(第6章:在大图上的随机(批次)训练)。
SAGEConv的数学公式如下:
源节点特征 feat_src
和目标节点特征 feat_dst
需要根据图类型被指定。 用于指定图类型并将 feat
扩展为 feat_src
和 feat_dst
的函数是 expand_as_pair()。 该函数的细节如下所示。
def expand_as_pair(input_, g=None):
if isinstance(input_, tuple):
# 二分图的情况
return input_
elif g is not None and g.is_block:
# 子图块的情况
if isinstance(input_, Mapping):
input_dst = {
k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
for k, v in input_.items()}
else:
input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
return input_, input_dst
else:
# 同构图的情况
return input_, input_
对于同构图上的全图训练,源节点和目标节点相同,它们都是图中的所有节点。
在异构图的情况下,图可以分为几个二分图,每种关系对应一个。关系表示为 (src_type, edge_type, dst_dtype)
。 当输入特征 feat
是1个元组时,图将会被视为二分图。元组中的第1个元素为源节点特征,第2个元素为目标节点特征。
在小批次训练中,计算应用于给定的一堆目标节点所采样的子图。子图在DGL中称为区块(block
)。 在区块创建的阶段,dst nodes
位于节点列表的最前面。通过索引 [0:g.number_of_dst_nodes()]
可以找到 feat_dst
。
确定 feat_src
和 feat_dst
之后,以上3种图类型的计算方法是相同的。
消息传递和聚合
import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape
if self._aggre_type == 'mean':
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
check_eq_shape(feat)
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
# 除以入度
degs = graph.in_degrees().to(feat_dst)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'pool':
graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
else:
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
# GraphSAGE中gcn聚合不需要fc_self
if self._aggre_type == 'gcn':
rst = self.fc_neigh(h_neigh)
else:
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
上面的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意,代码中的所有消息传递均使用 update_all() API和 DGL内置的消息/聚合函数来实现,以充分利用 2.2 编写高效的消息传递代码 里所介绍的性能优化。
聚合后,更新特征作为输出
# 激活函数
if self.activation is not None:
rst = self.activation(rst)
# 归一化
if self.norm is not None:
rst = self.norm(rst)
return rst
forward()
函数的最后一部分是在完成消息聚合后更新节点的特征。 常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。
异构图上的GraphConv模块
DGL提供了 HeteroGraphConv,用于定义异构图上GNN模块。 实现逻辑与消息传递级别的API multi_update_all() 相同,它包括:
- 每个关系上的DGL NN模块。
- 聚合来自不同关系上的结果。
其数学定义为:
HeteroGraphConv的实现逻辑
import torch.nn as nn
class HeteroGraphConv(nn.Module):
def __init__(self, mods, aggregate='sum'):
super(HeteroGraphConv, self).__init__()
self.mods = nn.ModuleDict(mods)
if isinstance(aggregate, str):
# 获取聚合函数的内部函数
self.agg_fn = get_aggregate_fn(aggregate)
else:
self.agg_fn = aggregate
异构图的卷积操作接受一个字典类型参数 mods
。这个字典的键为关系名,值为作用在该关系上NN模块对象。参数 aggregate
则指定了如何聚合来自不同关系的结果。
def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
if mod_args is None:
mod_args = {}
if mod_kwargs is None:
mod_kwargs = {}
outputs = {nty : [] for nty in g.dsttypes}
除了输入图和输入张量,forward()
函数还使用2个额外的字典参数 mod_args
和 mod_kwargs
。 这2个字典与 self.mods
具有相同的键,值则为对应NN模块的自定义参数。
forward()
函数的输出结果也是一个字典类型的对象。其键为 nty
,其值为每个目标节点类型 nty
的输出张量的列表, 表示来自不同关系的计算结果。HeteroGraphConv
会对这个列表进一步聚合,并将结果返回给用户。
if g.is_block:
src_inputs = inputs
dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
else:
src_inputs = dst_inputs = inputs
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.num_edges() == 0:
continue
if stype not in src_inputs or dtype not in dst_inputs:
continue
dstdata = self.mods[etype](
rel_graph,
(src_inputs[stype], dst_inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
输入 g
可以是异构图或来自异构图的子图区块。和普通的NN模块一样,forward()
函数需要分别处理不同的输入图类型。
上述代码中的for循环为处理异构图计算的主要逻辑。首先我们遍历图中所有的关系(通过调用 canonical_etypes
)。 通过关系名,我们可以使用g[ stype, etype, dtype ]
的语法将只包含该关系的子图( rel_graph
)抽取出来。 对于二分图,输入特征将被组织为元组 (src_inputs[stype], dst_inputs[dtype])
。 接着调用用户预先注册在该关系上的NN模块,并将结果保存在outputs
字典中。
rsts = {}
for nty, alist in outputs.items():
if len(alist) != 0:
rsts[nty] = self.agg_fn(alist, nty)
最后,HeteroGraphConv
会调用用户注册的 self.agg_fn
函数聚合来自多个关系的结果。 读者可以在API文档中找到 :class:~dgl.nn.pytorch.HeteroGraphConv 的示例。