超大图上的节点表征学习
我们使用内存数据集将图神经网络应用于许多节点或边的预测任务,然而在实际的工作中面临着超大图上进行图神经网络的训练,巨大的内存(显存)消耗问题。论文Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Network提出了一种新的图神经网络的训练方法。
Cluster-GCN方法简介
为了解决普通训练方法无法训练超大图的问题,Cluster-GCN论文提出:
- 利用图节点聚类算法将一个图的节点划分为 c c c个簇,每次选择几个簇的节点和这些节点对应的边构成一个子图,然后对子图做训练。
- 由于是利用图节点聚类算法将节点划分为多个簇,所以簇内边的数量要比簇间边的数量多得多,所以可以提高表征利用率,并提高图神经网络的训练效率。
- 每一次随机选择多个簇来组成一个batch,这样不会丢失簇间的边,同时也不会有batch内类别分布偏差过大的问题。
- 基于小图进行训练,不会消耗很多内存空间,于是我们可以训练更深的神经网络,进而可以达到更高的精度。
以往训练的瓶颈
以往的训练方法需要同时计算所有节点的表征以及训练集中所有节点的损失产生的梯度,因此需要非常巨大的计算及内存(显存)开销,由于神经网络在每个epoch中只更新一次,所以训练需要更多的epoch才能达到收敛。
采用mini-batch SGD的方式训练,可提高网络的训练速度并减少内存(显存)需求。SGD不需要计算完整梯度,而只需要计算部分梯度。
简单的Cluster-GCN方法
Cluster-GCN方法是通过找到一种将节点分成多个batch的方式,对应地将图划分成多个子图,使得表征利用率最大。
在每个batch中计算一组节点(记为 B \mathcal{B} B)从第 1 1 1层到第 L L L层的表征。由于图神经网络每一层的计算都使用相同的子图 A B , B A_{\mathcal{B}, \mathcal{B}} AB,B( B \mathcal{B} B内部的边),所以表征利用率就是这个batch内边的数量,记为 ∥ A B , B ∥ 0 \left\|A_{\mathcal{B}, \mathcal{B}}\right\|_{0} ∥AB,B∥0。因此,为了最大限度地提高表征利用率,理想的划分batch的结果是,batch内的边尽可能多,batch之间的边尽可能少。基于这一点,我们将SGD图神经网络训练的效率与图聚类算法联系起来。
图神经网络的构建
class Net(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super(Net, self).__init__()
self.convs = ModuleList(
[SAGEConv(in_channels, 128),
SAGEConv(128, out_channels)])
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs):
x = conv(x, edge_index)
if i != len(self.convs) - 1:
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
return F.log_softmax(x, dim=-1)
def inference(self, x_all):
pbar = tqdm(total=x_all.size(0) * len(self.convs))
pbar.set_description('Evaluating')
# Compute representations of nodes layer by layer, using *all*
# available edges. This leads to faster computation in contrast to
# immediately computing the final representations of each batch.
for i, conv in enumerate(self.convs):
xs = []
for batch_size, n_id, adj in subgraph_loader:
edge_index, _, size = adj.to(device)
x = x_all[n_id].to(device)
x_target = x[:size[1]]
x = conv((x, x_target), edge_index)
if i != len(self.convs) - 1:
x = F.relu(x)
xs.append(x.cpu())
pbar.update(batch_size)
x_all = torch.cat(xs, dim=0)
pbar.close()
return x_all
训练、验证与测试
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
def train():
model.train()
total_loss = total_nodes = 0
for batch in train_loader:
batch = batch.to(device)
optimizer.zero_grad()
out = model(batch.x, batch.edge_index)
loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
loss.backward()
optimizer.step()
nodes = batch.train_mask.sum().item()
total_loss += loss.item() * nodes
total_nodes += nodes
return total_loss / total_nodes
@torch.no_grad()
def test(): # Inference should be performed on the full graph.
model.eval()
out = model.inference(data.x)
y_pred = out.argmax(dim=-1)
accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
correct = y_pred[mask].eq(data.y[mask]).sum().item()
accs.append(correct / mask.sum().item())
return accs
for epoch in range(1, 31):
loss = train()
if epoch % 5 == 0:
train_acc, val_acc, test_acc = test()
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
f'Val: {val_acc:.4f}, test: {test_acc:.4f}')
else:
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')