图神经网络框架DGL教程-第5章:训练图神经网络

更多图神经网络和深度学习内容请关注:
在这里插入图片描述

第5章:训练图神经网络

本章通过使用 第2章:消息传递范式 中介绍的消息传递方法和 第3章:构建图神经网络(GNN)模块 中介绍的图神经网络模块, 讲解了如何对小规模的图数据进行节点分类、边分类、链接预测和整图分类的图神经网络的训练。

本章假设用户的图以及所有的节点和边特征都能存进GPU。对于无法全部载入的情况,请参考用户指南的 第6章:在大图上的随机(批次)训练

后续章节的内容均假设用户已经准备好了图和节点/边的特征数据。如果用户希望使用DGL提供的数据集或其他兼容 DGLDataset 的数据(如 第4章:图数据处理管道 所述), 可以使用类似以下代码的方法获取单个图数据集的图数据。

from dgl.data import CiteseerGraphDataset

dataset = CiteseerGraphDataset()
graph = dataset[0]
Using backend: pytorch

Finished data loading and preprocessing.
  NumNodes: 3327
  NumEdges: 9228
  NumFeats: 3703
  NumClasses: 6
  NumTrainingSamples: 120
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.

异构图训练的样例数据

有时用户会想在异构图上进行图神经网络的训练。本章会以下面代码所创建的一个异构图为例,来演示如何进行节点分类、边分类和链接预测的训练。

这个 hetero_graph 异构图有以下这些边的类型(类似无向边):

  • ('user', 'follow', 'user')
  • ('user', 'followed-by', 'user')
  • ('user', 'click', 'item')
  • ('item', 'clicked-by', 'user')
  • ('user', 'dislike', 'item')
  • ('item', 'disliked-by', 'user')
import numpy as np
import torch
import dgl

n_users = 1000
n_items = 500
n_follows = 3000
n_clicks = 5000
n_dislikes = 500
n_hetero_features = 10
n_user_classes = 5
n_max_clicks = 10

follow_src = np.random.randint(0, n_users, n_follows)
follow_dst = np.random.randint(0, n_users, n_follows)
click_src = np.random.randint(0, n_users, n_clicks)
click_dst = np.random.randint(0, n_items, n_clicks)
dislike_src = np.random.randint(0, n_users, n_dislikes)
dislike_dst = np.random.randint(0, n_items, n_dislikes)

hetero_graph = dgl.heterograph({
    ('user', 'follow', 'user'): (follow_src, follow_dst),
    ('user', 'followed-by', 'user'): (follow_dst, follow_src),
    ('user', 'click', 'item'): (click_src, click_dst),
    ('item', 'clicked-by', 'user'): (click_dst, click_src),
    ('user', 'dislike', 'item'): (dislike_src, dislike_dst),
    ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)
})

hetero_graph.nodes["user"].data["feature"] = torch.randn(n_users, n_hetero_features)
hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)
hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))
hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()
# 在user类型的节点和click类型的边上随机生成训练集的掩码
hetero_graph.nodes['user'].data["train_mask"] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)
hetero_graph.edges["click"].data["train_mask"] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)

5.1 节点分类/回归

对于图神经网络来说,最常见和被广泛使用的任务之一就是节点分类。 图数据中的训练、验证和测试集中的每个节点都具有从一组预定义的类别中分配的一个类别,即正确的标注。 节点回归任务也类似,训练、验证和测试集中的每个节点都被标注了一个正确的数字。

概述

为了对节点进行分类,图神经网络执行了 第2章:消息传递范式 中介绍的消息传递机制,利用节点自身的特征和其邻节点及边的特征来计算节点的隐藏表示。 消息传递可以重复多轮,以利用更大范围的邻居信息。

编写神经网络模型

DGL提供了一些内置的图卷积模块,可以完成一轮消息传递计算。 本章中选择 dgl.nn.pytorch.SAGEConv 作为演示的样例代码(针对MXNet和PyTorch及TensorFlow后端有对应的模块), 它是GraphSAGE模型中使用的图卷积模块。

对于图上的深度学习模型,通常需要一个多层的图神经网络,并在这个网络中要进行多轮的信息传递。 可以通过堆叠图卷积模块来实现这种网络架构,具体如下所示。

# 构建一个2层的GNN模型
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        # 实例化SAGEConve,in_feats是输入特征的维度,out_feats是输出特征的维度,aggregator_type是聚合函数的类型
        self.conv1 = dglnn.SAGEConv(
            in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')

    def forward(self, graph, inputs):
        # inputs是节点的特征
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        return h

请注意,这个模型不仅可以做节点分类,还可以为其他下游任务获取隐藏节点表示,如: 5.2 边分类/回归5.3 链接预测5.4 整图分类

关于DGL内置图卷积模块的完整列表,读者可以参考 dgl.nn

有关DGL神经网络模块如何工作,以及如何编写一个自定义的带有消息传递的GNN模块的更多细节,请参考 第3章:构建图神经网络(GNN)模块 中的例子

模型的训练

全图(使用所有的节点和边的特征)上的训练只需要使用上面定义的模型进行前向传播计算,并通过在训练节点上比较预测和真实标签来计算损失,从而完成后向传播。

本节使用DGL内置的数据集 dgl.data.CiteseerGraphDataset 来展示模型的训练。 节点特征和标签存储在其图上,训练、验证和测试的分割也以布尔掩码的形式存储在图上。这与在 第4章:图数据处理管道 中的做法类似。

from dgl.data import CiteseerGraphDataset

dataset = CiteseerGraphDataset(raw_dir="")
graph = dataset[0]

train_mask = graph.ndata["train_mask"]
val_mask = graph.ndata["val_mask"]
test_mask = graph.ndata["test_mask"]

node_features = graph.ndata["feat"]
node_labels = graph.ndata["label"]

n_features = node_features.shape[1] # 特征数量
n_labels = int(node_labels.max()+1) # 标签数量
  NumNodes: 3327
  NumEdges: 9228
  NumFeats: 3703
  NumClasses: 6
  NumTrainingSamples: 120
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.

下面是通过使用准确性来评估模型的一个例子。

def evaluate(model, graph, features, labels, mask):
    import torch
    
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

用户可以按如下方式实现模型的训练。

import torch

model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())

for epoch in range(10):
    model.train()
    # 使用所有节点(全图)进行前向传播计算
    logits = model(graph, node_features)
    # 计算损失值
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
    # 计算验证集的准确度
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)
    # 进行反向传播计算
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(epoch,"-",loss.item())

    # 如果需要的话,保存训练好的模型。本例中省略。

DGL的GraphSAGE样例 提供了一个端到端的同构图节点分类的例子。用户可以在 GraphSAGE 类中看到模型实现的细节。 这个模型具有可调节的层数、dropout概率,以及可定制的聚合函数和非线性函数。

完整代码

"""
Inductive Representation Learning on Large Graphs
Paper: http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf
Code: https://github.com/williamleif/graphsage-simple
Simple reference implementation of GraphSAGE.
"""
import argparse
import time
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.nn.pytorch.conv import SAGEConv


class GraphSAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 aggregator_type):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

        # input layer
        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type))
        # output layer
        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type)) # activation None

    def forward(self, graph, inputs):
        h = self.dropout(inputs)
        for l, layer in enumerate(self.layers):
            h = layer(graph, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h


def evaluate(model, graph, features, labels, nid):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[nid]
        labels = labels[nid]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

def main(args):
    # load and preprocess dataset
    data = load_data(args)
    g = data[0]
    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    in_feats = features.shape[1]
    n_classes = data.num_classes
    n_edges = data.graph.number_of_edges()
    print("""----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           train_mask.int().sum().item(),
           val_mask.int().sum().item(),
           test_mask.int().sum().item()))

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        print("use cuda:", args.gpu)

    train_nid = train_mask.nonzero().squeeze()
    val_nid = val_mask.nonzero().squeeze()
    test_nid = test_mask.nonzero().squeeze()

    # graph preprocess and calculate normalization factor
    g = dgl.remove_self_loop(g)
    n_edges = g.number_of_edges()
    if cuda:
        g = g.int().to(args.gpu)

    # create GraphSAGE model
    model = GraphSAGE(in_feats,
                      args.n_hidden,
                      n_classes,
                      args.n_layers,
                      F.relu,
                      args.dropout,
                      args.aggregator_type)

    if cuda:
        model.cuda()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # initialize graph
    dur = []
    for epoch in range(args.n_epochs):
        model.train()
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(g, features)
        loss = F.cross_entropy(logits[train_nid], labels[train_nid])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        acc = evaluate(model, g, features, labels, val_nid)
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
              "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
                                            acc, n_edges / np.mean(dur) / 1000))

    print()
    acc = evaluate(model, g, features, labels, test_nid)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GraphSAGE')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    parser.add_argument("--aggregator-type", type=str, default="gcn",
                        help="Aggregator type: mean/gcn/pool/lstm")
    
    args = parser.parse_known_args()[0]

"""
start:为了使得代码能在jupyter中直接使用添加了以下代码,可选数据库为cora|citeseer|pubmed
"""
#     args.dataset = "cora"
    args.dataset = "citeseer"
#     args.dataset = "pubmed"
"""
end
"""

#     args = parser.parse_args()
    print(args)

    main(args)

Namespace(aggregator_type=‘gcn’, dataset=‘citeseer’, dropout=0.5, gpu=-1, lr=0.01, n_epochs=200, n_hidden=16, n_layers=1, weight_decay=0.0005)
NumNodes: 3327
NumEdges: 9228
NumFeats: 3703
NumClasses: 6
NumTrainingSamples: 120
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
----Data statistics------’
#Edges 9228
#Classes 6
#Train samples 120
#Val samples 500
#Test samples 1000
Epoch 00000 | Time(s) nan | Loss 1.7920 | Accuracy 0.1880 | ETputs(KTEPS) nan
Epoch 00001 | Time(s) nan | Loss 1.7858 | Accuracy 0.1920 | ETputs(KTEPS) nan
Epoch 00002 | Time(s) nan | Loss 1.7786 | Accuracy 0.1900 | ETputs(KTEPS) nan
Epoch 00003 | Time(s) 0.1539 | Loss 1.7712 | Accuracy 0.2060 | ETputs(KTEPS) 59.15
Epoch 00004 | Time(s) 0.1404 | Loss 1.7608 | Accuracy 0.2400 | ETputs(KTEPS) 64.83
Epoch 00005 | Time(s) 0.1406 | Loss 1.7520 | Accuracy 0.2680 | ETputs(KTEPS) 64.76
Epoch 00006 | Time(s) 0.1394 | Loss 1.7448 | Accuracy 0.3040 | ETputs(KTEPS) 65.30

Epoch 00197 | Time(s) 0.1541 | Loss 0.4280 | Accuracy 0.6960 | ETputs(KTEPS) 59.06
Epoch 00198 | Time(s) 0.1540 | Loss 0.4434 | Accuracy 0.6980 | ETputs(KTEPS) 59.11
Epoch 00199 | Time(s) 0.1540 | Loss 0.4104 | Accuracy 0.7020 | ETputs(KTEPS) 59.12
Test Accuracy 0.6970

异构图上的节点分类模型的训练

如果图是异构的,用户可能希望沿着所有边类型从邻居那里收集消息。 用户可以使用 dgl.nn.pytorch.HeteroGraphConv 模块(针对MXNet和PyTorch后端也有对应的模块)在所有边类型上执行消息传递, 并为每种边类型使用一种图卷积模块。

下面的代码定义了一个异构图卷积模块。模块首先对每种边类型进行单独的图卷积计算,然后将每种边类型上的消息聚合结果再相加, 并作为所有节点类型的最终结果。

# Define a Heterograph Conv model

class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        # 实例化HeteroGraphConv,in_feats是输入特征的维度,out_feats是输出特征的维度,aggregate是聚合函数的类型
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # 输入是节点的特征字典
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

dgl.nn.HeteroGraphConv 接收一个节点类型和节点特征张量的字典作为输入,并返回另一个节点类型和节点特征的字典。

本章的 异构图训练的样例数据 中已经有了 useritem 的特征,用户可用如下代码获取

model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
labels = hetero_graph.nodes['user'].data['label']
train_mask = hetero_graph.nodes['user'].data['train_mask']

然后,用户可以简单地按如下形式进行前向传播计算:

node_features = {'user': user_feats, 'item': item_feats}
h_dict = model(hetero_graph, {'user': user_feats, 'item': item_feats})
h_user = h_dict['user']
h_item = h_dict['item']

异构图上模型的训练和同构图的模型训练是一样的,只是这里使用了一个包括节点表示的字典来计算预测值。 例如,如果只预测 user 节点的类别,用户可以从返回的字典中提取 user 的节点嵌入。

opt = torch.optim.Adam(model.parameters())

for epoch in range(5):
    model.train()
    # 使用所有节点的特征进行前向传播计算,并提取输出的user节点嵌入
    logits = model(hetero_graph, node_features)['user']
    # 计算损失值
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    # 计算验证集的准确度。在本例中省略。
    # 进行反向传播计算
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

    # 如果需要的话,保存训练好的模型。本例中省略。
1.8446298837661743
1.8297122716903687
1.8156903982162476
1.80256187915802
1.7902932167053223

DGL提供了一个用于节点分类的RGCN的端到端的例子 RGCN 。用户可以在 RGCN模型实现文件 中查看异构图卷积 RelGraphConvLayer 的具体定义。

5.2 边分类/回归

5.2 边分类/回归

5.3 链接预测

5.3 链接预测

5.4 整图分类

5.4 整图分类

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值