Cluster-GCN
Graph neural networks have been applied successfully to many node- and edge-level prediction tasks, yet training them on very large graphs remains challenging for two main reasons:
- As the number of GNN layers grows, the computational cost grows exponentially, because each node's receptive field expands with every layer.
- Keeping the whole graph plus every node's embedding at every layer in memory (GPU memory) consumes an enormous amount of space.
The two standard gradient-descent strategies each have drawbacks:
- Full-batch gradient descent: large memory footprint, low time complexity (only one parameter update per epoch), hard to converge quickly.
- Mini-batch gradient descent: small memory footprint, higher time complexity, comparatively easy to converge.
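To make the contrast concrete, here is a minimal sketch of one epoch under each strategy, assuming a generic PyTorch classifier that returns log-probabilities, plus hypothetical tensors X, y and a loader (none of these names appear in the code below):

import torch.nn.functional as F

def full_batch_epoch(model, optimizer, X, y):
    # One forward/backward pass over *all* samples: large memory footprint,
    # but only a single parameter update per epoch, so convergence is slow.
    optimizer.zero_grad()
    loss = F.nll_loss(model(X), y)
    loss.backward()
    optimizer.step()

def mini_batch_epoch(model, optimizer, loader):
    # Many cheap passes, one per batch: small memory footprint and many
    # parameter updates per epoch, so convergence is easier.
    for X_b, y_b in loader:
        optimizer.zero_grad()
        loss = F.nll_loss(model(X_b), y_b)
        loss.backward()
        optimizer.step()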
The paper Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks proposes a new training method for graph neural networks, called Cluster-GCN.
The main contributions of Cluster-GCN:
- A graph clustering algorithm partitions the graph's nodes into c clusters. At each training step, a few clusters, together with the edges among their nodes, form a subgraph, and training is performed on that subgraph.
- Because the partition comes from a graph clustering algorithm, within-cluster edges greatly outnumber between-cluster edges. This raises embedding utilization (an embedding computed for one node is reused by its many in-batch neighbors) and thus improves training efficiency.
- Each batch is formed by randomly selecting several clusters (see the sketch after this list), so between-cluster edges are not discarded and the label distribution within a batch does not become overly skewed.
- Training on small subgraphs keeps memory consumption low, which makes it possible to train deeper networks and thereby reach higher accuracy.
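As a rough illustration of this stochastic multiple-cluster batching (not the paper's reference implementation), the sketch below groups a precomputed node partition into random batches; clusters is assumed to be the output of a graph clustering algorithm such as METIS:

import random

def cluster_gcn_batches(clusters, q):
    """Yield node-index batches by merging q randomly chosen clusters.

    clusters: list of lists of node indices, from a graph clustering algorithm.
    q:        number of clusters combined into one batch.
    """
    order = list(range(len(clusters)))
    random.shuffle(order)  # re-randomized every epoch
    for i in range(0, len(order), q):
        batch_nodes = []
        for c in order[i:i + q]:
            batch_nodes.extend(clusters[c])
        # The training subgraph is induced on the union of these nodes,
        # so edges *between* the selected clusters are kept.
        yield batch_nodes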
Cluster-GCN Code Walkthrough
1. Loading the dataset
from torch_geometric.datasets import Reddit
# In PyG >= 2.0, ClusterData, ClusterLoader and NeighborSampler live in torch_geometric.loader.
from torch_geometric.data import ClusterData, ClusterLoader, NeighborSampler
dataset = Reddit('../dataset/Reddit')
data = dataset[0]
print(dataset.num_classes)
print(data.num_nodes)
print(data.num_edges)
print(data.num_features)
# 41
# 232965
# 114615873
# 602
2. Graph clustering and data loader construction
# Partition the graph into 1500 clusters with METIS; the result is cached under save_dir.
cluster_data = ClusterData(data, num_parts=1500, recursive=False, save_dir=dataset.processed_dir)
# Each training batch merges 20 randomly selected clusters into one subgraph.
train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True, num_workers=12)
# For full-graph inference: layer-wise loading that keeps all neighbors (sizes=[-1]).
subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=1024, shuffle=False, num_workers=12)
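To see what ClusterLoader yields, it helps to inspect a single batch: each one is a small Data object holding the subgraph induced on roughly num_nodes / num_parts * batch_size nodes. The printout below is illustrative only, since the exact sizes depend on the random cluster grouping:

batch = next(iter(train_loader))
print(batch)
# e.g. Data(x=[~3100, 602], edge_index=[2, ...], y=[~3100], train_mask=[~3100], ...)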
3. Building the graph neural network
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from torch_geometric.nn import SAGEConv
from tqdm import tqdm

class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        # Two-layer GraphSAGE: in_channels -> 128 -> out_channels.
        self.convs = ModuleList(
            [SAGEConv(in_channels, 128),
             SAGEConv(128, out_channels)])

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return F.log_softmax(x, dim=-1)

    def inference(self, x_all):
        pbar = tqdm(total=x_all.size(0) * len(self.convs))
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch.
        for i, conv in enumerate(self.convs):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = conv((x, x_target), edge_index)
                if i != len(self.convs) - 1:
                    x = F.relu(x)
                xs.append(x.cpu())
                pbar.update(batch_size)
            x_all = torch.cat(xs, dim=0)
        pbar.close()
        return x_all
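The conv((x, x_target), edge_index) call uses PyG's bipartite message-passing convention: NeighborSampler places the target nodes first in n_id, so x[:size[1]] selects exactly the nodes whose new embeddings this batch computes. The shapes involved look roughly like this (numbers are made up for illustration):

# One NeighborSampler batch, illustrative shapes only:
#   n_id:      [5000]            target nodes plus all of their neighbors
#   size:      (5000, 1024)      (num_source_nodes, num_target_nodes)
#   x:         [5000, in_ch]     features of all source nodes
#   x_target:  [1024, in_ch]     x[:size[1]] -- targets come first in n_id
#   conv((x, x_target), edge_index)  ->  [1024, out_ch]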
4. Training, validation, and testing
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

def train():
    model.train()
    total_loss = total_nodes = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()

        # Weight the running loss by the number of training nodes in the batch.
        nodes = batch.train_mask.sum().item()
        total_loss += loss.item() * nodes
        total_nodes += nodes
    return total_loss / total_nodes

@torch.no_grad()
def test():  # Inference should be performed on the full graph.
    model.eval()
    out = model.inference(data.x)
    y_pred = out.argmax(dim=-1)
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = y_pred[mask].eq(data.y[mask]).sum().item()
        accs.append(correct / mask.sum().item())
    return accs

for epoch in range(1, 31):
    loss = train()
    if epoch % 5 == 0:  # run the expensive full-graph evaluation every 5 epochs
        train_acc, val_acc, test_acc = test()
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
              f'Val: {val_acc:.4f}, test: {test_acc:.4f}')
    else:
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
The output with num_parts=1500 is as follows:
Downloading https://data.dgl.ai/dataset/reddit.zip
Extracting data/Reddit/raw/reddit.zip
Processing…
Done!
Computing METIS partitioning…
Done!
Epoch: 01, Loss: 1.1537
Epoch: 02, Loss: 0.4769
Epoch: 03, Loss: 0.3911
Epoch: 04, Loss: 0.3565
Evaluating: 100%|██████████| 465930/465930 [01:01<00:00, 7564.35it/s]
Epoch: 05, Loss: 0.3414, Train: 0.9585, Val: 0.9552, test: 0.9533
Epoch: 06, Loss: 0.3218
Epoch: 07, Loss: 0.3130
Epoch: 08, Loss: 0.3201
Epoch: 09, Loss: 0.3004
Evaluating: 100%|██████████| 465930/465930 [00:54<00:00, 8525.57it/s]
Epoch: 10, Loss: 0.2799, Train: 0.9655, Val: 0.9546, test: 0.9540
Epoch: 11, Loss: 0.2745
Epoch: 12, Loss: 0.2827
Epoch: 13, Loss: 0.2673
Epoch: 14, Loss: 0.2670
Evaluating: 100%|██████████| 465930/465930 [00:53<00:00, 8710.35it/s]
Epoch: 15, Loss: 0.2761, Train: 0.9659, Val: 0.9530, test: 0.9518
Epoch: 16, Loss: 0.2655
Epoch: 17, Loss: 0.2613
Epoch: 18, Loss: 0.2550
Epoch: 19, Loss: 0.2656
Evaluating: 100%|██████████| 465930/465930 [00:54<00:00, 8511.85it/s]
Epoch: 20, Loss: 0.2635, Train: 0.9646, Val: 0.9501, test: 0.9490
Epoch: 21, Loss: 0.2507
Epoch: 22, Loss: 0.2461
Epoch: 23, Loss: 0.2428
Epoch: 24, Loss: 0.2419
Evaluating: 100%|██████████| 465930/465930 [00:54<00:00, 8551.65it/s]
Epoch: 25, Loss: 0.2388, Train: 0.9703, Val: 0.9554, test: 0.9530
Epoch: 26, Loss: 0.2371
Epoch: 27, Loss: 0.2364
Epoch: 28, Loss: 0.2359
Epoch: 29, Loss: 0.2362
Evaluating: 100%|██████████| 465930/465930 [00:53<00:00, 8664.79it/s]
Epoch: 30, Loss: 0.2332, Train: 0.9690, Val: 0.9514, test: 0.9504
Experimental results
To study the effect of the number of clusters, we repartition the dataset with different num_parts values, retrain, and compare the results.
The experiment uses num_parts = 500, 1000, 1500, and 2000; the results are shown in the table below (a sketch of the sweep loop follows the table):
num_parts | Epoch=5 | Epoch=10 | Epoch=15 | Epoch=20 | Epoch=25 | Epoch=30 |
---|---|---|---|---|---|---|
500 | Train: 0.9512, Val: 0.9517, test: 0.9505 | Train: 0.9596, Val: 0.9530, test: 0.9542 | Train: 0.9630, Val: 0.9546, test: 0.9531 | Train: 0.9705, Val: 0.9578, test: 0.9575 | Train: 0.9682, Val: 0.9546, test: 0.9519 | Train: 0.9735, Val: 0.9573, test: 0.9552 |
1000 | Train: 0.9557, Val: 0.9535, test: 0.9520 | Train: 0.9641, Val: 0.9560, test: 0.9546 | Train: 0.9682, Val: 0.9571, test: 0.9558 | Train: 0.9688, Val: 0.9540, test: 0.9544 | Train: 0.9560, Val: 0.9463, test: 0.9450 | Train: 0.9725, Val: 0.9527, test: 0.9538 |
1500 | Train: 0.9585, Val: 0.9552, test: 0.9533 | Train: 0.9655, Val: 0.9546, test: 0.9540 | Train: 0.9659, Val: 0.9530, test: 0.9518 | Train: 0.9646, Val: 0.9501, test: 0.9490 | Train: 0.9703, Val: 0.9554, test: 0.9530 | Train: 0.9690, Val: 0.9514, test: 0.9504 |
2000 | Train: 0.9588, Val: 0.9549, test: 0.9518 | Train: 0.9591, Val: 0.9508, test: 0.9501 | Train: 0.9587, Val: 0.9488, test: 0.9460 | Train: 0.9666, Val: 0.9515, test: 0.9499 | Train: 0.9638, Val: 0.9478, test: 0.9451 | Train: 0.9688, Val: 0.9519, test: 0.9506 |
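The sweep itself can be scripted by rebuilding the cluster data, loader, model, and optimizer for each setting and reusing the train()/test() routines above. A minimal sketch (save_dir is omitted here so each partitioning is recomputed instead of being loaded from the previous run's cache):

for num_parts in [500, 1000, 1500, 2000]:
    cluster_data = ClusterData(data, num_parts=num_parts, recursive=False)
    train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True,
                                 num_workers=12)
    model = Net(dataset.num_features, dataset.num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    for epoch in range(1, 31):
        loss = train()
        if epoch % 5 == 0:
            train_acc, val_acc, test_acc = test()
            print(f'num_parts: {num_parts}, Epoch: {epoch:02d}, '
                  f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, test: {test_acc:.4f}')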
On the Reddit dataset, the best result in this sweep was obtained with 500 clusters at epoch 20 (Val: 0.9578, test: 0.9575).
This article is based on the Datawhale open course.
Reference: https://www.bilibili.com/video/BV1UE411b7PN