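Node classification is one of the most popular and widely adopted tasks for graph neural networks: every node in the training, validation, and test sets is assigned a ground-truth class from a predefined set of categories.
To classify nodes, a graph neural network performs message passing, using each node's own features together with the features of its neighboring nodes and edges. Message passing can be repeated for several rounds so that information from a wider neighborhood is aggregated.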
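To make the idea concrete, here is a minimal sketch of message passing in plain Python. It is not DGL's implementation: the toy graph, the scalar features, and the helper name `mean_message_passing` are all assumptions made for illustration, and the update rule is a simple mean over a node and its neighbors.

```python
# Toy undirected path graph 0 - 1 - 2 - 3 as an adjacency list (hypothetical data).
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
# One scalar feature per node; only node 0 starts with a nonzero value.
features = {0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0}


def mean_message_passing(neighbors, features):
    """One round: each node averages its own feature with its neighbors' features."""
    updated = {}
    for node, nbrs in neighbors.items():
        gathered = [features[node]] + [features[n] for n in nbrs]
        updated[node] = sum(gathered) / len(gathered)
    return updated


one_round = mean_message_passing(neighbors, features)
two_rounds = mean_message_passing(neighbors, one_round)

# After one round node 2 has not yet seen node 0's signal; after two rounds it has.
print(one_round[2], two_rounds[2])
```

Node 2 is two hops away from node 0, so it only picks up node 0's signal in the second round, while node 3 (three hops away) is still untouched. This is why more rounds correspond to a larger receptive field.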
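The DGL framework provides several built-in graph convolution modules, each of which performs one round of message passing.
In this post we use the SAGEConv module from dgl.nn.pytorch, which comes from the GraphSAGE paper, Inductive Representation Learning on Large Graphs.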
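As a rough sketch of what one such layer computes, the mean-aggregator variant described in the GraphSAGE paper can be written as below. The notation is chosen here for illustration: $h_i^{(l)}$ is the representation of node $i$ at layer $l$, $\mathcal{N}(i)$ is its set of neighbors, $W^{(l)}$ is a learned weight matrix, and $\sigma$ is a nonlinearity. Other aggregator types, such as `gcn`, differ in the details.

```latex
h_{\mathcal{N}(i)}^{(l+1)} = \operatorname{mean}\left(\left\{ h_j^{(l)} : j \in \mathcal{N}(i) \right\}\right),
\qquad
h_i^{(l+1)} = \sigma\left( W^{(l)} \cdot \operatorname{concat}\left( h_i^{(l)},\; h_{\mathcal{N}(i)}^{(l+1)} \right) \right)
```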
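Deep learning models on graphs usually need a multi-layer graph neural network, which performs multiple rounds of message passing. This can be achieved by stacking graph convolution modules as follows.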
1 Building the GNN model
First import the required packages (this post uses DGL version 0.5.2):
import dgl.nn as dglnn
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import CoraGraphDataset
Build a two-layer GNN model:
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, dropout=0.2):
        super().__init__()
        # First layer: input features -> hidden features, 'gcn' aggregator
        self.conv1 = dglnn.SAGEConv(
            in_feats=in_feats, out_feats=hid_feats, feat_drop=0.2, aggregator_type='gcn')
        # Second layer: hidden features -> class logits, 'mean' aggregator
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, feat_drop=0.2, aggregator_type='mean')
        self.dropout = nn.Dropout(dropout)

    def forward(self, graph, inputs):
        # inputs holds the node features, shape [N, in_feats]
        h = self.conv1(graph, inputs)
        h = self.dropout(F.relu(h))
        h = self.conv2(graph, h)
        return h
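Note that the model above is not limited to node classification: it can also produce node representations for other downstream tasks, such as edge classification/regression, link prediction, or graph classification.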
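As one illustration of such a downstream use, the sketch below scores candidate edges for link prediction by taking the dot product of the two endpoint embeddings. This is a common choice rather than anything prescribed by this post: the embedding values are made up and stand in for the output of a trained GNN, and `score_edges` is a hypothetical helper.

```python
import torch

# Hypothetical node embeddings standing in for a trained GNN's output: 4 nodes, 2 dims.
h = torch.tensor([[1.0, 0.0],
                  [1.0, 0.0],
                  [0.0, 1.0],
                  [-1.0, 0.0]])


def score_edges(h, src, dst):
    """Dot-product score for each candidate edge (src[i], dst[i])."""
    return (h[src] * h[dst]).sum(dim=-1)


# Candidate edges 0->1, 0->2, 0->3
src = torch.tensor([0, 0, 0])
dst = torch.tensor([1, 2, 3])
scores = score_edges(h, src, dst)
# A sigmoid turns the raw scores into values that can be read as link probabilities.
probs = torch.sigmoid(scores)
print(scores, probs)
```

Nodes with similar embeddings (0 and 1) get the highest score, orthogonal ones (0 and 2) a neutral score, and opposite ones (0 and 3) the lowest.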
2 Dataset and data analysis
dataset = CoraGraphDataset() # Cora citation network dataset
graph = dataset[0]
graph = dgl.remove_self_loop(graph)  # remove self-loops
node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1)
print("Number of nodes and edges: ", graph.num_nodes(), graph.num_edges())
print("Training nodes:", train_mask.sum().item())
print("Validation nodes:", valid_mask.sum().item())
print("Test nodes:", test_mask.sum().item())
print("Node feature dimension:", n_features)
print("Number of label classes:", n_labels)
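Beyond the split sizes printed above, it can be useful to check how the classes are distributed inside a split. The sketch below does this with a boolean mask and `torch.bincount`. It uses made-up labels and a made-up mask rather than the Cora tensors, and `class_counts` is a hypothetical helper.

```python
import torch

# Hypothetical labels and a boolean training mask for 6 nodes and 3 classes.
labels = torch.tensor([0, 2, 1, 2, 2, 0])
train_mask = torch.tensor([True, True, False, True, False, True])


def class_counts(labels, mask, num_classes):
    """Count how many masked nodes fall into each class."""
    return torch.bincount(labels[mask], minlength=num_classes)


counts = class_counts(labels, train_mask, num_classes=3)
print(counts.tolist())
```

With the real data, the same helper could be called as `class_counts(node_labels, train_mask, n_labels)`, assuming the mask is a boolean tensor as it is in this toy example.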
Randomly sample 200 nodes and plot the induced subgraph:
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
G = graph.to_networkx()
# Sample without replacement so that exactly 200 distinct nodes are drawn
res = np.random.choice(G.number_of_nodes(), size=200, replace=False).tolist()
k = G.subgraph(res)
pos = nx.spring_layout(k)
plt.figure()
nx.draw(k, pos=pos, node_size=8)
plt.savefig('cora.jpg', dpi=600)
plt.show()
3 Training and evaluating the model
def evaluate(model, graph, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)
model = SAGE(in_feats=n_features, hid_feats=128, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())
import os

# Make sure the checkpoint directory exists before training starts
os.makedirs('save_model', exist_ok=True)
# Start training
best_val_acc = 0
for epoch in range(200):
    print('Epoch {}'.format(epoch))
    model.train()
    # Forward pass over all nodes
    logits = model(graph, node_features)
    # Compute the loss on the training nodes only
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
    # Backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    print('loss = {:.4f}'.format(loss.item()))
    # Compute validation accuracy after the update, so the saved weights match the reported score
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)
    if acc > best_val_acc:
        best_val_acc = acc
        torch.save(model.state_dict(), 'save_model/best_model.pth')
    print("current val acc = {}, best val acc = {}".format(acc, best_val_acc))
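The loop above runs the forward pass over every node but computes the loss only on the training nodes. The sketch below isolates that masking step with random stand-in logits (not output from the model in this post) and checks which rows receive a gradient. In a real GNN, unlabeled nodes still influence the result through message passing; here the logits are leaf tensors, so only the masking effect is visible.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Stand-in logits for 5 nodes and 3 classes (hypothetical values).
logits = torch.randn(5, 3, requires_grad=True)
labels = torch.tensor([0, 1, 2, 1, 0])
train_mask = torch.tensor([True, False, True, False, False])

# The loss uses only the rows selected by the training mask.
loss = F.cross_entropy(logits[train_mask], labels[train_mask])
loss.backward()

# Rows outside the mask receive an all-zero gradient.
grad_is_zero = (logits.grad == 0).all(dim=1)
print(grad_is_zero.tolist())
```

Validation and test nodes therefore never contribute a supervised signal during training, which is what makes the validation accuracy a usable criterion for choosing the checkpoint.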
Evaluation on the test set:
model.load_state_dict(torch.load("save_model/best_model.pth"))
acc = evaluate(model, graph, node_features, node_labels, test_mask)
print("test accuracy: ", acc)