pytorch 训练过程acc_【图节点分类】10分钟就学会的图节点分类教程,基于pytorch和dgl...

本文介绍了如何使用PyTorch和DGL库进行图节点分类任务,涉及构建GNN模型、数据集分析以及模型训练与评估。文章通过实例展示了使用SAGEConv模块进行多轮消息传递,并提供了数据可视化和测试集评估的步骤。
摘要由CSDN通过智能技术生成

af1b8a0bdba0d5ee60b530b659622655.png

图神经网络中最流行和广泛采用的任务之一就是节点分类,其中训练集/验证集/测试集中的每个节点从一组预定义的类别中分配一个真实类别。

为了对节点进行分类,图神经网络利用节点自身的特征,以及相邻节点和边的特征进行消息传递。消息传递可以重复多次,以聚合来自更大范围的邻居节点的信息。

dgl框架为我们提供了一些内置的图卷积模块,可以执行一轮的消息传递。

在本文中,我们使用dgl.nn.pytorch的SAGEConv模块,该模块来自这篇论文GraphSAGE:Inductive Representation Learning on Large Graphs

通常对于图上的深度学习模型,我们需要一个多层图神经网络,在这里我们进行多轮的消息传递。这可以通过如下方式堆叠图卷积模块来实现。

1 构造GNN模型

先导入必要包(本文dgl 版本为 0.5.2)

import dgl.nn as dglnn 
import dgl 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from dgl.data import * 

构造一个两层的gnn模型

class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, dropout=0.2):
        super().__init__()
        self.conv1 = dglnn.SAGEConv( 
            in_feats=in_feats, out_feats=hid_feats, feat_drop=0.2, aggregator_type='gcn')
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, feat_drop=0.2, aggregator_type='mean')
        self.dropout =  nn.Dropout(dropout)
    
    def forward(self, graph, inputs):
        # inputs 是节点的特征 [N, in_feas]
        h = self.conv1(graph, inputs)
        h = self.dropout(F.relu(h))
        h = self.conv2(graph, h)
        return h 

注意,我们不仅可以将上面的模型用于节点分类,还可以获取节点的特征表示为了其他下游任务,如边分类/回归、链接预测或图分类。

2 数据集与数据分析

dataset = CoraGraphDataset() # Cora citation network dataset
graph = dataset[0]
graph = dgl.remove_self_loop(graph)  # 消除自环
node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1) 

6bdc6ca89da2a5f0c738c90aa04b045a.png
print("图的节点数和边数: ", graph.num_nodes(), graph.num_edges())
print("训练集节点数:", train_mask.sum().item())
print("验证集集节点数:", valid_mask.sum().item()) 
print("测试集节点数:", test_mask.sum().item())
print("节点特征维数:", n_features)
print("标签类目数:", n_labels)

f8b1e9031a9c9ece71ac4d5a3617c60e.png

随机抽200个节点并画图展示:

import networkx as nx 
import numpy as np
import matplotlib.pyplot as plt 

G = graph.to_networkx() 
res = np.random.randint(0, high=G.number_of_nodes(), size=(200))

k = G.subgraph(res) 
pos = nx.spring_layout(k)

plt.figure()
nx.draw(k, pos=pos, node_size=8 )
plt.savefig('cora.jpg', dpi=600)
plt.show()

5bd7cab1edb53535a73c6b1e16deff26.png

3 训练模型与评估

def evaluate(model, graph, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

model = SAGE(in_feats=n_features, hid_feats=128, out_feats=n_labels) 
opt = torch.optim.Adam(model.parameters())

# 开始训练
best_val_acc = 0
for epoch in range(200): 
    print('Epoch {}'.format(epoch))
    model.train()
    # 用所有的节点进行前向传播
    logits = model(graph, node_features)
    # 计算损失
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
    # 计算验证集accuracy
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)
    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    
    print('loss = {:.4f}'.format(loss.item()))
    if acc > best_val_acc:
        best_val_acc = acc
        torch.save(model.state_dict(), 'save_model/best_model.pth')
    print("current val acc = {}, best val acc = {}".format(acc, best_val_acc))

e759b9a5222ed51c78313f37f4aabdf3.png

测试集评估

model.load_state_dict(torch.load("save_model/best_model.pth"))
acc = evaluate(model, graph, node_features, node_labels, test_mask)
print("test accuracy: ", acc)

078c095797fba90edaae28f051032aa7.png

完结:-) 觉得有用记得双击点赞呀!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值