文章目录
背景
图神经网络的局限性:训练效率与可扩展性。
-
基于SGD的图神经网络的训练方法,随着图神经网络层数增加,计算成本呈指数增长;
-
保存整个图的信息和每一层每个节点的表征到内存(显存)而消耗巨大内存(显存)空间;
-
“邻居爆炸(Neighbor Explosion)”:在信息传递网络中,每一层的都需要将信息从中心点传给周围的邻居,经过很多层后,该信息到达的节点数指数增长。
解决方案:邻接点采样,每一层的信息传递无需用到所有邻接点,以减少计算量和内存需求。
- 点采样:GraphSAGE–>PinSAGE、VR-GCN
- 层采样:FastGCN–>ASGCN
- 图采样:Cluster-GCN、GraphSAINT
Cluster-GCN方法
简介
为了解决普通训练方法无法训练超大图的问题,Cluster-GCN论文提出:
- 利用图节点聚类算法将一个图的节点划分为 c c c个簇,每一次选择几个簇的节点和这些节点对应的边构成一个子图,然后对子图做训练;
- 由于是利用图节点聚类算法将节点划分为多个簇,所以簇内边的数量要比簇间边的数量多得多,所以可以提高表征利用率,并提高图神经网络的训练效率;
- 每一次随机选择多个簇来组成一个batch,这样不会丢失簇间的边,同时也不会有batch内类别分布偏差过大的问题;
- 基于小图进行训练,不会消耗很多内存空间,于是我们可以训练更深的神经网络,进而可以达到更高的精度。
基本方法
mini-batch SGD:采用mini-batch SGD的方式训练图神经网络,可提高训练速度并减少内存(显存)需求。这是因为,在参数更新中,SGD不需要计算完整梯度,而只需要基于mini-batch计算部分梯度。使用 B ⊆ [ N ] \mathcal{B} \subseteq[N] B⊆[N]来表示一个batch,其大小为 b = ∣ B ∣ b=|\mathcal{B}| b=∣B∣。SGD的每一步都将计算梯度估计值 1 ∣ B ∣ ∑ i ∈ B ∇ loss ( y i , z i ( L ) ) \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \nabla \operatorname{loss}\left(y_{i}, z_{i}^{(L)}\right) ∣B∣1∑i∈B∇loss(yi,zi(L))来进行参数更新。
表征利用率:在训练过程中,如果节点 i i i在 l l l层的表征 z i ( l ) z_{i}^{(l)} zi(l)被计算并在 l + 1 l+1 l+1层的表征计算中被重复使用 u u u次,那么我们说 z i ( l ) z_{i}^{(l)} zi(l)的表征利用率为 u u u。表征利用率越大,计算效率越高。
由于在每个batch中计算一组节点(记为 B \mathcal{B} B)从第 1 1 1层到第 L L L层的表征,且每一层的计算都使用相同的子图 A B , B A_{\mathcal{B}, \mathcal{B}} AB,B( B \mathcal{B} B内部的边),所以表征利用率就是这个batch内边的数量,记为 ∥ A B , B ∥ 0 \left\|A_{\mathcal{B}, \mathcal{B}}\right\|_{0} ∥AB,B∥0。为了最大限度地提高表征利用率,理想的划分batch的结果是,batch内的边尽可能多,batch之间的边尽可能少,与图节点聚类思想不谋而合。
综上,Cluster-GCN基于图节点聚类,在聚类结果上使用SGD训练网络。具体方法如下:
对于一个图
G
G
G,我们将其节点划分为
c
c
c个簇:
V
=
[
V
1
,
⋯
V
c
]
\mathcal{V}=\left[\mathcal{V}_{1}, \cdots \mathcal{V}_{c}\right]
V=[V1,⋯Vc],其中
V
t
\mathcal{V}_{t}
Vt由第
t
t
t个簇中的节点组成,对应的我们有
c
c
c个子图:
KaTeX parse error: Undefined control sequence: \notag at position 161: …right\}\right] \̲n̲o̲t̲a̲g̲ ̲
其中
E
t
\mathcal{E}_{t}
Et只由
V
t
\mathcal{V}_{t}
Vt中的节点之间的边组成。经过节点重组,邻接矩阵被划分为大小为
c
2
c^{2}
c2的块矩阵,如下所示
A
=
A
ˉ
+
Δ
=
[
A
11
⋯
A
1
c
⋮
⋱
⋮
A
c
1
⋯
A
c
c
]
A=\bar{A}+\Delta=\left[\begin{array}{ccc} A_{11} & \cdots & A_{1 c} \\ \vdots & \ddots & \vdots \\ A_{c 1} & \cdots & A_{c c} \end{array}\right]
A=Aˉ+Δ=⎣⎢⎡A11⋮Ac1⋯⋱⋯A1c⋮Acc⎦⎥⎤
其中
A
ˉ
=
[
A
11
⋯
0
⋮
⋱
⋮
0
⋯
A
c
c
]
,
Δ
=
[
0
⋯
A
1
c
⋮
⋱
⋮
A
c
1
⋯
0
]
\bar{A}=\left[\begin{array}{ccc} A_{11} & \cdots & 0 \\ \vdots & \ddots & \vdots \\ 0 & \cdots & A_{c c} \end{array}\right], \Delta=\left[\begin{array}{ccc} 0 & \cdots & A_{1 c} \\ \vdots & \ddots & \vdots \\ A_{c 1} & \cdots & 0 \end{array}\right]
Aˉ=⎣⎢⎡A11⋮0⋯⋱⋯0⋮Acc⎦⎥⎤,Δ=⎣⎢⎡0⋮Ac1⋯⋱⋯A1c⋮0⎦⎥⎤
其中,对角线上的块
A
t
t
A_{t t}
Att是大小为
∣
V
t
∣
×
∣
V
t
∣
\left|\mathcal{V}_{t}\right| \times\left|\mathcal{V}_{t}\right|
∣Vt∣×∣Vt∣的邻接矩阵,它由
G
t
G_{t}
Gt内部的边构成。
A
ˉ
\bar{A}
Aˉ是图
G
ˉ
\bar{G}
Gˉ的邻接矩阵。
A
s
t
A_{s t}
Ast由两个簇
V
s
\mathcal{V}_{s}
Vs和
V
t
\mathcal{V}_{t}
Vt之间的边构成。
Δ
\Delta
Δ是由
A
A
A的所有非对角线块组成的矩阵。同样,我们可以根据
[
V
1
,
⋯
,
V
c
]
\left[\mathcal{V}_{1}, \cdots, \mathcal{V}_{c}\right]
[V1,⋯,Vc]划分节点表征矩阵
X
X
X和类别向量
Y
Y
Y,得到
[
X
1
,
⋯
,
X
c
]
\left[X_{1}, \cdots, X_{c}\right]
[X1,⋯,Xc]和
[
Y
1
,
⋯
,
Y
c
]
\left[Y_{1}, \cdots, Y_{c}\right]
[Y1,⋯,Yc],其中
X
t
X_{t}
Xt和
Y
t
Y_{t}
Yt分别由
V
t
V_{t}
Vt中节点的表征和类别组成。
接下来用块对角线邻接矩阵
A
ˉ
\bar{A}
Aˉ去近似邻接矩阵
A
A
A,这样做的好处是,完整的损失函数可以根据batch分解成多个部分(Cluster)之和。以
A
ˉ
′
\bar{A}^{\prime}
Aˉ′表示归一化后的
A
ˉ
\bar{A}
Aˉ,最后一层节点表征矩阵可以做如下的分解:
Z
(
L
)
=
A
ˉ
′
σ
(
A
ˉ
′
σ
(
⋯
σ
(
A
ˉ
′
X
W
(
0
)
)
W
(
1
)
)
⋯
)
W
(
L
−
1
)
=
[
A
ˉ
11
′
σ
(
A
ˉ
11
′
σ
(
⋯
σ
(
A
ˉ
11
′
X
1
W
(
0
)
)
W
(
1
)
)
⋯
)
W
(
L
−
1
)
⋮
A
ˉ
c
c
′
σ
(
A
ˉ
c
c
′
σ
(
⋯
σ
(
A
ˉ
c
c
′
X
c
W
(
0
)
)
W
(
1
)
)
⋯
)
W
(
L
−
1
)
]
\begin{aligned} Z^{(L)} &=\bar{A}^{\prime} \sigma\left(\bar{A}^{\prime} \sigma\left(\cdots \sigma\left(\bar{A}^{\prime} X W^{(0)}\right) W^{(1)}\right) \cdots\right) W^{(L-1)} \\ &=\left[\begin{array}{c} \bar{A}_{11}^{\prime} \sigma\left(\bar{A}_{11}^{\prime} \sigma\left(\cdots \sigma\left(\bar{A}_{11}^{\prime} X_{1} W^{(0)}\right) W^{(1)}\right) \cdots\right) W^{(L-1)} \\ \vdots \\ \bar{A}_{c c}^{\prime} \sigma\left(\bar{A}_{c c}^{\prime} \sigma\left(\cdots \sigma\left(\bar{A}_{c c}^{\prime} X_{c} W^{(0)}\right) W^{(1)}\right) \cdots\right) W^{(L-1)} \end{array}\right] \end{aligned}
Z(L)=Aˉ′σ(Aˉ′σ(⋯σ(Aˉ′XW(0))W(1))⋯)W(L−1)=⎣⎢⎡Aˉ11′σ(Aˉ11′σ(⋯σ(Aˉ11′X1W(0))W(1))⋯)W(L−1)⋮Aˉcc′σ(Aˉcc′σ(⋯σ(Aˉcc′XcW(0))W(1))⋯)W(L−1)⎦⎥⎤
由于
A
ˉ
\bar{A}
Aˉ是块对角形式(
A
ˉ
t
t
′
\bar{A}_{t t}^{\prime}
Aˉtt′是
A
ˉ
′
\bar{A}^{\prime}
Aˉ′的对角线上的块),于是损失函数可以分解为
$$
\mathcal{L}{\bar{A}^{\prime}}=\sum{t} \frac{\left|\mathcal{V}{t}\right|}{N} \mathcal{L}{\bar{A}_{t t}^{\prime}} \
\mathcal{L}{\bar{A}{t t}^{\prime}}=\frac{1}{\left|\mathcal{V}{t}\right|} \sum{i \in \mathcal{V}{t}} \operatorname{loss}\left(y{i}, z_{i}^{(L)}\right)
$$
基于上述分析,在训练中的每一步,Cluster-GCN首先采样一个簇
V
t
\mathcal{V}_{t}
Vt,然后根据
L
A
ˉ
′
t
t
\mathcal{L}_{{\bar{A}^{\prime}}_{tt}}
LAˉ′tt的梯度进行参数更新,该过程只需要用到子图
A
t
t
A_{t t}
Att,
X
t
X_{t}
Xt,
Y
t
Y_{t}
Yt以及神经网络权重矩阵
{
W
(
l
)
}
l
=
1
L
\left\{W^{(l)}\right\}_{l=1}^{L}
{W(l)}l=1L。 实际中,主要的计算开销在神经网络前向过程中的矩阵乘法运算
Z
(
L
)
Z^{(L)}
Z(L)和梯度反向传播,这比以前SGD中使用的邻域搜索过程更容易实现。
通过下图原始GCN(左)与Cluster-GCN(右)的节点消息传递过程可以看到,Cluster-GCN可以避免大量的邻域搜索,消息传递集中在每个簇中的邻居上。
Cluster-GCN时间复杂度:由于簇 V t \mathcal{V}_{t} Vt中每个节点只连接到该簇内部的节点,节点的邻域扩展不需要在簇外进行。每个batch的计算将纯粹是矩阵乘积运算( A ˉ t t ′ X t ( l ) W ( l ) \bar{A}_{t t}^{\prime} X_{t}^{(l)} W^{(l)} Aˉtt′Xt(l)W(l))和一些对元素的操作(ReLU),因此,每个batch的总体时间复杂度为 O ( ∥ A t t ∥ 0 F + b F 2 ) O\left(\left\|A_{t t}\right\|_{0} F+ b F^{2}\right) O(∥Att∥0F+bF2),每个epoch的总体时间复杂度为 O ( ∥ A ∥ 0 F + N F 2 ) O\left(\|A\|_{0} F+N F^{2}\right) O(∥A∥0F+NF2)。平均来说,每个batch只需要计算 O ( b L ) O(b L) O(bL)的表征,这是线性的,而不是指数级的。
Cluster-GCN空间复杂度:在每个batch中,只需要在每一层中存储 b b b个节点的表征,产生用于存储表征的内存(显存)开销为 O ( b L F ) O(b L F) O(bLF)。此外,该算法只需加载子图到内存(显存)中,而不是完整的图。
Cluster-GCN改进:随机多簇法
存在问题:
- 图被分割后,一些边(块邻接矩阵 A \mathcal{A} A中的 Δ \Delta Δ部分)被移除,性能可能因此会受到影响;
- 图聚类算法倾向于将相似的节点聚集在一起,导致单个簇中节点的类别分布可能与原始数据集不同,使得对梯度的估计有偏差。
**解决办法:**随机多簇法。
- 首先将图划分为 p p p个簇 V 1 , ⋯ , V p \mathcal{V}_{1}, \cdots, \mathcal{V}_{p} V1,⋯,Vp, p p p是一个较大的值;
- 在构建一个batch时,随机选择的 q q q个簇( t 1 , … , t q t_{1}, \ldots, t_{q} t1,…,tq)进行合并,减少分布差异性,由此得到的batch包含节点 { V t 1 ∪ ⋯ ∪ V t q } \left\{\mathcal{V}_{t_{1}} \cup \cdots \cup \mathcal{V}_{t_{q}}\right\} {Vt1∪⋯∪Vtq} 、簇内边 { A i i ∣ i ∈ t 1 , … , t q } \left\{A_{i i} \mid i \in t_{1}, \ldots, t_{q}\right\} {Aii∣i∈t1,…,tq}和簇间边 { A i j ∣ i , j ∈ t 1 , … , t q } \left\{A_{i j} \mid i, j \in t_{1}, \ldots, t_{q}\right\} {Aij∣i,j∈t1,…,tq}。
随机多簇法的优点:
- 不会丢失簇间的边(部分解决问题1)
- 不会有很大的batch内类别分布的偏差(部分解决问题2)
- 不同的epoch使用的batch不同,降低梯度估计的偏差(部分解决问题2)
改进后的Cluster-GCN算法流程图如下:
扩展:深层GCN的设计与训练
存在问题:
- 过去在小规模图上进行GCN的研究,产生这样的观点:更深的GCN,即增加更多的层对最终结果影响不大;
- 加深GCN神经网络层数后,训练变得很困难,因为层数多了之后前面的信息可能无法传到后面;
- 原始GCN中每个节点都聚合邻接节点在上一层的表征。然而,在深层GCN里该策略可能不适合,因为它没有考虑到层数的问题。直观地说,近距离的邻接节点应该比远距离的的邻接节点贡献更大。
解决办法:放大GCN每一层中使用的邻接矩阵 A A A的对角线部分。通过这种方式,在GCN的每一层的聚合中对来自上一层的表征赋予更大的权重。
-
方法1: X ( l + 1 ) = σ ( ( A ′ + I ) X ( l ) W ( l ) ) X^{(l+1)}=\sigma\left(\left(A^{\prime}+I\right) X^{(l)} W^{(l)}\right) X(l+1)=σ((A′+I)X(l)W(l))
- 该方法仍有以下问题需要改进:1)对所有节点使用相同的权重而未考虑其邻居的数量;2)当使用更多的层时,数值会呈指数级增长。
-
方法2:首先给原始的 A A A添加一个单位矩阵 I I I,并进行归一化处理,然后对角增强。
X ( l + 1 ) = σ ( ( A ~ + λ diag ( A ~ ) ) X ( l ) W ( l ) ) 其 中 , A ~ = ( D + I ) − 1 ( A + I ) X^{(l+1)}=\sigma\left((\tilde{A}+\lambda \operatorname{diag}(\tilde{A})) X^{(l)} W^{(l)}\right) \\ 其中,\tilde{A}=(D+I)^{-1}(A+I) X(l+1)=σ((A~+λdiag(A~))X(l)W(l))其中,A~=(D+I)−1(A+I)
Cluster-GCN实践
数据集采集与预处理
from torch_geometric.datasets import Reddit
from torch_geometric.data import ClusterData, ClusterLoader, NeighborSampler
# 数据集获取与分析
dataset = Reddit('../dataset/Reddit')
data = dataset[0]
print(dataset.num_classes) #41
print(data.num_nodes) #232965
print(data.num_edges) #114615873
print(data.num_features) #602
# 图节点聚类与数据加载器生成
cluster_data = ClusterData(data, num_parts=1500, recursive=False, save_dir=dataset.processed_dir)
train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True, num_workers=8) # 此数据加载器遵循Cluster-GCN提出的方法:图节点被聚类划分成多个簇,加载器返回的一个batch由多个簇组成
subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=1024, shuffle=False, num_workers=8) # 此数据加载器不对图节点聚类:计算一个batch中的节点的表征需要计算该batch中的所有节点L跳内的邻居节点
-
ClusterData
:继承自torch.utils.data.Dataset
CLASS ClusterData(data, num_parts: int, recursive: bool = False, save_dir: Optional[str] = None, log: bool = True) # 将大图分成若干子图,即实现Cluster-GCN中的图节点聚类 # data (torch_geometric.data.Data):全图数据对象 # num_parts:聚类类别数 # recursive:设置为True,使用多层递归二分法,否则使用多层k路分区法 # save_dir:子图数据存储文件地址
-
ClusterLoader
:继承自torch.utils.data.DataLoader
CLASS ClusterLoader(cluster_data, **kwargs) # Cluster-GCN中的改进随机多簇法实现 # cluster_data (torch_geometric.data.ClusterData):分簇图数据对象 # **kwargs:设置batch_size、shuffle、drop_last、num_workers等参数
-
NeighborSampler
:继承自torch.utils.data.DataLoader
CLASS NeighborSampler(edge_index, sizes, node_idx, num_nodes, return_e_id, transform, **kwargs) # 论文“Inductive Representation Learning on Large Graphs”中的邻居采样器实现 # edge_index:邻接矩阵索引 # sizes ([int]):每一层节点的邻居采样数。如果设置为 size[l] = -1,表示对所有邻居节点采样 # node_idx:mini batch包含节点,如果设置为None,则将考虑所有节点 # num_nodes:图中节点数 # return_e_id (bool):如果设置为 False,则不会返回采样边的原始边索引,仅适用于没有边特征的图数据,默认为True # transform:数据转换函数 # **kwargs:设置batch_size、shuffle、drop_last、num_workers等参数
Cluster-GCN的构建、训练与测试
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from tqdm import tqdm
from torch_geometric.nn import SAGEConv
# Cluster-GCN构建
class Net(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super(Net, self).__init__()
self.convs = ModuleList(
[SAGEConv(in_channels, 128),
SAGEConv(128, out_channels)])
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs):
x = conv(x, edge_index)
if i != len(self.convs) - 1:
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
return F.log_softmax(x, dim=-1)
def inference(self, x_all):
pbar = tqdm(total=x_all.size(0) * len(self.convs))
pbar.set_description('Evaluating')
for i, conv in enumerate(self.convs):
xs = []
for batch_size, n_id, adj in subgraph_loader:
edge_index, _, size = adj.to(device)
x = x_all[n_id].to(device)
x_target = x[:size[1]]
x = conv((x, x_target), edge_index)
if i != len(self.convs) - 1:
x = F.relu(x)
xs.append(x.cpu())
pbar.update(batch_size)
x_all = torch.cat(xs, dim=0)
pbar.close()
return x_all
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
def train():
model.train()
total_loss = total_nodes = 0
for batch in train_loader:
batch = batch.to(device)
optimizer.zero_grad()
out = model(batch.x, batch.edge_index)
loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
loss.backward()
optimizer.step()
nodes = batch.train_mask.sum().item()
total_loss += loss.item() * nodes
total_nodes += nodes
return total_loss / total_nodes
@torch.no_grad()
def test():
model.eval()
out = model.inference(data.x) #在全图上执行推理测试
y_pred = out.argmax(dim=-1)
accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
correct = y_pred[mask].eq(data.y[mask]).sum().item()
accs.append(correct / mask.sum().item())
return accs
for epoch in range(1, 31):
loss = train()
if epoch % 5 == 0:
train_acc, val_acc, test_acc = test()
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
f'Val: {val_acc:.4f}, test: {test_acc:.4f}')
else:
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
-
ModuleList
:是一个储存不同module,并自动将每个module的parameters添加到网络中的容器。通过extend
,append
等操作可以把任意nn.Module
的子类 (比如nn.Conv2d
,nn.Linear
等) 加到这个list
里面。nn.Sequential
与nn.ModuleList
的区别:nn.Sequential
内部实现了forward
函数,因此可以不用写forward
函数;而nn.ModuleList
则没有实现内部forward函数;nn.Sequential
可以使用OrderedDict
对每层进行命名;nn.Sequential
里面的模块按照顺序进行排列的,所以必须确保前一个模块的输出大小和下一个模块的输入大小是一致的;而nn.ModuleList
并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言。
from torch.nn import ModuleList, Sequential class netML(nn.Module): def __init__(self): super(netML, self).__init__() self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)]) def forward(self, x): x = self.linears[2](x) x = self.linears[0](x) x = self.linears[1](x) return x class netS(nn.Module): def __init__(self): super(netS, self).__init__() self.block = nn.Sequential(nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU()) def forward(self, x): x = self.block(x) return x
参考
- DataWhale GNN组队学习
- 《图神经网络:基础与前沿》 马腾飞 / 编著
- pytorch学习笔记 torchnn.ModuleList