超大图上的节点表征学习
参考资料:
论文Cluster-GCN: An Efficient Algorithm for Training Deep andLarge Graph Convolutional Network
team-learning-nlp
图神经网络在超大图上进行图神经网络的训练仍然具有挑战。普通的基于SGD的图神经网络的训练方法存在两个问题
- 随着图神经网络层数增加,计算成本呈指数增长的问题,
- 保存整个图的信息和每一层每个节点的表征到内存(显存)而消耗巨大内存(显存)空间的问题。
虽然已经有一些论文提出了无需保存整个图的信息和每一层每个节点的表征到GPU内存(显存)的方法,但这些方法可能会损失预测精度或者对提高内存的利用率并不明显。于是论文Cluster-GCN: An Efficient Algorithm for Training Deep andLarge Graph Convolutional Network提出了一种新的图神经网络的训练方法。在此篇文章中,我们将首先对Cluster-GCN论文中提出的方法(后文中简称为Cluster-GCN方法)做简单概括,接着深入分析超大图上的节点表征学习面临的挑战,最后对Cluster-GCN方法做深入分析。
为了解决普通训练方法无法训练超大图的问题,Cluster-GCN论文提出:
- 利用图节点聚类算法将一个图的节点划分为个簇,每一次选择几个簇的节点和这些节点对应的边构成一个子图,然后对子图做训练。
- 由于是利用图节点聚类算法将节点划分为多个簇,所以簇内边的数量要比簇间边的数量多得多,所以可以提高表征利用率,并提高图神经网络的训练效率。
- 每一次随机选择多个簇来组成一个batch,这样不会丢失簇间的边,同时也不会
有batch内类别分布偏差过大的问题。 - 基于小图进行训练,不会消耗很多内存空间,于是我们可以训练更深的神经网
络,进而可以达到更高的精度。
简单的Cluster-GCN方法
为了最大限度地提高表征利用率,理想的划分batch的结果是,batch内的边尽可能多,batch之间的边尽可能少。在训练的每一步中,Cluster-GCN首先采样一个簇,然后根据损失的梯度进行参数更新。
图节点聚类算法将图节点分成多个簇,划分结果是簇内边的数量远多于簇间边的数量。每个batch的表征利用率相当于簇内边的数量。直观地说,每个节点和它的邻接节点大部分情况下都位于同一个簇中,因此跳(L-hop)远的邻接节点大概率仍然在同一个簇中。
由于我们用块对角线近似邻接矩阵代替邻接矩阵,产生的误差与簇间的边的数量成正比,所以簇间的边越少越好。综上所述,使用图节点聚类算法对图节点划分多个簇的结果,正是我们希望得到的。
随机多分区
尽管简单Cluster-GCN方法可以做到较其他方法更低的计算和内存复杂度,但它仍
存在两个潜在问题:
- 图被分割后,一些边(公式(4)中的部分)被移除,性能可能因此会受到影响。
- 图聚类算法倾向于将相似的节点聚集在一起。因此,单个簇中节点的类别分布可能与原始数据集不同,导致对梯度的估计有偏差。
为了解决上述问题,Cluster-GCN论文提出了一种随机多簇方法,此方法首先将图划分为个簇,在构建一个batch时,不是只使用一个簇, 而是使用随机选择的q个簇,得到的batch 包含节点、簇内边和簇间边。此方法的好处有,1)不会丢失簇间的边,2)不会有很大的batch内类别分布的偏差,3)以及不同的epoch使用的batch不同,这可以降低梯度估计的偏差。
训练深层GCNs的问题
在原始的
GCN的设置里,每个节点都聚合邻接节点在上一层的表征。然而,在深层GCN的设置里,该策略可能不适合,因为它没有考虑到层数的问题。直观地说,近距离的邻接节点应该比远距离的的邻接节点贡献更大。因此,Cluster-GCN提出一种技术来更好地解决这个问题。其主要思想是放大GCN每一层中使用的邻接矩阵的对角线部分。通过这种方式,我们在GCN的每一层的聚合中对来自上一层的表征赋予更大的权重。
代码实践
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from tqdm import tqdm
from torch_geometric.datasets import Reddit, Reddit2
from torch_geometric.data import ClusterData, ClusterLoader, NeighborSampler
from torch_geometric.nn import SAGEConv
# 需要下载 https://data.dgl.ai/dataset/reddit.zip 到 data/Reddit 文件夹下
dataset = Reddit('./data/Reddit')#加载数据
# dataset = Reddit2('data/Reddit2')
data = dataset[0]
# 图节点聚类与数据加载器生成
#ClusterData数据加载器遵循Cluster-GCN提出的方法,图节点被聚类划分成多个簇,此数据加载器返回的一个batch由多个簇组成。
cluster_data = ClusterData(data, num_parts=1500, recursive=False,
save_dir=dataset.processed_dir)
train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True,
num_workers=12)
# 使用NeighborSampler数据加载器不对图节点聚类,计算一个batch中的节点的表征需要计算该batch中的所有节点的距离从到的邻居节点
subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=1024,
shuffle=False, num_workers=12)
class Net(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super(Net, self).__init__()
self.convs = ModuleList(
[SAGEConv(in_channels, 128),
SAGEConv(128, out_channels)])
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs):
x = conv(x, edge_index)
if i != len(self.convs) - 1:
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
return F.log_softmax(x, dim=-1)
#inference方法应用于推理阶段
def inference(self, x_all):
pbar = tqdm(total=x_all.size(0) * len(self.convs))
pbar.set_description('Evaluating')
# Compute representations of nodes layer by layer, using *all*
# available edges. This leads to faster computation in contrast to
# immediately computing the final representations of each batch.
for i, conv in enumerate(self.convs):
xs = []
for batch_size, n_id, adj in subgraph_loader:#为了获取更高的预测精度,所以使用subgraph_loader
edge_index, _, size = adj.to(device)
x = x_all[n_id].to(device)
x_target = x[:size[1]]
x = conv((x, x_target), edge_index)
if i != len(self.convs) - 1:
x = F.relu(x)
xs.append(x.cpu())
pbar.update(batch_size)
x_all = torch.cat(xs, dim=0)
pbar.close()
return x_all
# 训练、验证与测试
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
def train():
model.train()
total_loss = total_nodes = 0
#使用train_loader获取batch
for batch in train_loader:
batch = batch.to(device)
# 根据多个簇组成的batch进行神经网络的训练
optimizer.zero_grad()
out = model(batch.x, batch.edge_index)
loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
loss.backward()
optimizer.step()
nodes = batch.train_mask.sum().item()
total_loss += loss.item() * nodes
total_nodes += nodes
return total_loss / total_nodes
# 验证阶段,我们使用subgraph_loader,在计算一个节点的表征时会计算该节点的距离从到的邻接节点,这么做可以更好地测试神经网络的性能。
@torch.no_grad()
def test(): # Inference should be performed on the full graph.
model.eval()
out = model.inference(data.x)
y_pred = out.argmax(dim=-1)
accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
correct = y_pred[mask].eq(data.y[mask]).sum().item()
accs.append(correct / mask.sum().item())
return accs
for epoch in range(1, 31):
loss = train()
if epoch % 5 == 0:
train_acc, val_acc, test_acc = test()
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
f'Val: {val_acc:.4f}, test: {test_acc:.4f}')
else:
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
Using exist file reddit.zip
Extracting data/Reddit/raw/reddit.zip
Processing...
Done!
Computing METIS partitioning...
Done!
Epoch: 01, Loss: 1.0812
Epoch: 02, Loss: 0.4512
Epoch: 03, Loss: 0.3842
Epoch: 04, Loss: 0.3508
Evaluating: 100%|██████████| 465930/465930 [01:08<00:00, 6778.44it/s]
Epoch: 05, Loss: 0.3430, Train: 0.9563, Val: 0.9512, test: 0.9506
Epoch: 06, Loss: 0.3169
Epoch: 07, Loss: 0.3021
Epoch: 08, Loss: 0.2969
Epoch: 09, Loss: 0.3109
Evaluating: 100%|██████████| 465930/465930 [01:08<00:00, 6752.90it/s]
Epoch: 10, Loss: 0.2952, Train: 0.9648, Val: 0.9563, test: 0.9538
Epoch: 11, Loss: 0.2737
Epoch: 12, Loss: 0.2801
Epoch: 13, Loss: 0.2662
Epoch: 14, Loss: 0.2616
Evaluating: 100%|██████████| 465930/465930 [01:09<00:00, 6713.43it/s]
Epoch: 15, Loss: 0.2615, Train: 0.9668, Val: 0.9551, test: 0.9534
Epoch: 16, Loss: 0.2600
Epoch: 17, Loss: 0.2759
Epoch: 18, Loss: 0.2532
Epoch: 19, Loss: 0.2550
Evaluating: 100%|██████████| 465930/465930 [01:10<00:00, 6591.34it/s]
Epoch: 20, Loss: 0.2522, Train: 0.9684, Val: 0.9542, test: 0.9518
Epoch: 21, Loss: 0.2575
Epoch: 22, Loss: 0.2846
Epoch: 23, Loss: 0.2461
Epoch: 24, Loss: 0.2506
Evaluating: 100%|██████████| 465930/465930 [01:06<00:00, 7054.68it/s]
Epoch: 25, Loss: 0.2723, Train: 0.9699, Val: 0.9541, test: 0.9530
Epoch: 26, Loss: 0.2445
Epoch: 27, Loss: 0.2475
Epoch: 28, Loss: 0.2395
Epoch: 29, Loss: 0.2312
Evaluating: 100%|██████████| 465930/465930 [01:08<00:00, 6808.91it/s]
Epoch: 30, Loss: 0.2413, Train: 0.9686, Val: 0.9502, test: 0.9494