Graph Neural Networks, Part 3: Node Representation Learning (to be continued)

This article uses the Cora dataset to show how MLP, GCN, and GAT perform on node classification. GCN and GAT exploit the graph structure to improve the quality of node representations, while the MLP relies only on node content. GAT's attention mechanism makes graph convolution more flexible, whereas GCN smooths node features through the adjacency matrix. The results show that graph neural networks outperform conventional deep neural networks on node classification.

In graph node prediction or edge prediction tasks, we first need to generate node representations. We use a graph neural network to produce these representations and train it with supervised learning so that it learns to produce high-quality node representations. High-quality node representations can be used to measure the similarity between nodes, and they are also a prerequisite for classifying nodes accurately.

In this section, we will learn how to implement a multi-layer graph neural network and, using node classification as the example task, walk through the typical procedure for training a GNN. We use the Cora dataset throughout: Cora is a citation network in which nodes represent papers; two nodes are connected by an edge if one paper cites the other, and each node's attribute is a 1433-dimensional bag-of-words feature vector. The task is to predict the category of each paper (7 classes in total). We will also compare MLP with GCN and GAT (two well-known graph neural networks) on node classification, to demonstrate the power of graph neural networks and to explain why they outperform ordinary deep neural networks.

This section is organized as follows:

  1. First, we do some preparation: we obtain and analyze the dataset, and build a helper for visualizing the distribution of node representations.
  2. Then, we examine how an MLP performs on node classification and look at the distribution of the node representations it learns.
  3. Next, we introduce the GCN and GAT graph neural networks one by one, and compare their performance on node classification as well as the quality of the node representations they learn.
  4. Finally, we compare the node representation learning abilities of the three models.

To demonstrate the power of graph neural networks, we compare the node representation learning abilities of MLP, GCN, and GAT (the latter two being well-known graph neural networks) through the node classification task. We begin by loading and inspecting the Cora dataset:

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

dataset = Planetoid(root='dataset', name='Cora', transform=NormalizeFeatures())

print()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('======================')

# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
Dataset: Cora():
======================
Number of graphs: 1
Number of features: 1433
Number of classes: 7

Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
======================
Number of nodes: 2708
Number of edges: 10556
Average node degree: 3.90
Number of training nodes: 140
Training node label rate: 0.05
Contains isolated nodes: False
Contains self-loops: False
Is undirected: True
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def visualize(h, color):
    z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())  # project node representations to 2D
    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])

    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()

1. Node Classification with an MLP

In principle, we should be able to infer a paper's category from its content alone, i.e. its bag-of-words feature representation, without considering any relational information between papers. Let us verify this by building a simple MLP. This network only transforms the input node features and shares its weights across all nodes.

Constructing the MLP

import torch
from torch.nn import Linear
import torch.nn.functional as F

class MLP(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(MLP, self).__init__()
        torch.manual_seed(12345)
        self.lin1 = Linear(dataset.num_features, hidden_channels)
        self.lin2 = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x):
        x = self.lin1(x)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x

model = MLP(hidden_channels=16)
print(model)
MLP(
  (lin1): Linear(in_features=1433, out_features=16, bias=True)
  (lin2): Linear(in_features=16, out_features=7, bias=True)
)
model = MLP(hidden_channels=16)
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)  # Define optimizer.

def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss

for epoch in range(1, 201):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
Epoch: 001, Loss: 1.9615
Epoch: 002, Loss: 1.9557
Epoch: 003, Loss: 1.9505
Epoch: 004, Loss: 1.9423
Epoch: 005, Loss: 1.9327
Epoch: 006, Loss: 1.9279
Epoch: 007, Loss: 1.9144
Epoch: 008, Loss: 1.9087
Epoch: 009, Loss: 1.9023
Epoch: 010, Loss: 1.8893
Epoch: 011, Loss: 1.8776
Epoch: 012, Loss: 1.8594
Epoch: 013, Loss: 1.8457
Epoch: 014, Loss: 1.8365
Epoch: 015, Loss: 1.8280
Epoch: 016, Loss: 1.7965
Epoch: 017, Loss: 1.7984
Epoch: 018, Loss: 1.7832
Epoch: 019, Loss: 1.7495
Epoch: 020, Loss: 1.7441
Epoch: 021, Loss: 1.7188
Epoch: 022, Loss: 1.7124
Epoch: 023, Loss: 1.6785
Epoch: 024, Loss: 1.6660
Epoch: 025, Loss: 1.6119
Epoch: 026, Loss: 1.6236
Epoch: 027, Loss: 1.5827
Epoch: 028, Loss: 1.5784
Epoch: 029, Loss: 1.5524
Epoch: 030, Loss: 1.5020
Epoch: 031, Loss: 1.5065
Epoch: 032, Loss: 1.4742
Epoch: 033, Loss: 1.4581
Epoch: 034, Loss: 1.4246
Epoch: 035, Loss: 1.4131
Epoch: 036, Loss: 1.4112
Epoch: 037, Loss: 1.3923
Epoch: 038, Loss: 1.3055
Epoch: 039, Loss: 1.2982
Epoch: 040, Loss: 1.2543
Epoch: 041, Loss: 1.2244
Epoch: 042, Loss: 1.2331
Epoch: 043, Loss: 1.1984
Epoch: 044, Loss: 1.1796
Epoch: 045, Loss: 1.1093
Epoch: 046, Loss: 1.1284
Epoch: 047, Loss: 1.1229
Epoch: 048, Loss: 1.0383
Epoch: 049, Loss: 1.0439
Epoch: 050, Loss: 1.0563
Epoch: 051, Loss: 0.9893
Epoch: 052, Loss: 1.0508
Epoch: 053, Loss: 0.9343
Epoch: 054, Loss: 0.9639
Epoch: 055, Loss: 0.8929
Epoch: 056, Loss: 0.8705
Epoch: 057, Loss: 0.9176
Epoch: 058, Loss: 0.9239
Epoch: 059, Loss: 0.8641
Epoch: 060, Loss: 0.8578
Epoch: 061, Loss: 0.7908
Epoch: 062, Loss: 0.7856
Epoch: 063, Loss: 0.7683
Epoch: 064, Loss: 0.7816
Epoch: 065, Loss: 0.7356
Epoch: 066, Loss: 0.6951
Epoch: 067, Loss: 0.7300
Epoch: 068, Loss: 0.6939
Epoch: 069, Loss: 0.7550
Epoch: 070, Loss: 0.6864
Epoch: 071, Loss: 0.7094
Epoch: 072, Loss: 0.7238
Epoch: 073, Loss: 0.7150
Epoch: 074, Loss: 0.6191
Epoch: 075, Loss: 0.6770
Epoch: 076, Loss: 0.6487
Epoch: 077, Loss: 0.6258
Epoch: 078, Loss: 0.5821
Epoch: 079, Loss: 0.5637
Epoch: 080, Loss: 0.6368
Epoch: 081, Loss: 0.6333
Epoch: 082, Loss: 0.6434
Epoch: 083, Loss: 0.5974
Epoch: 084, Loss: 0.6176
Epoch: 085, Loss: 0.5972
Epoch: 086, Loss: 0.4690
Epoch: 087, Loss: 0.6362
Epoch: 088, Loss: 0.6118
Epoch: 089, Loss: 0.5248
Epoch: 090, Loss: 0.5520
Epoch: 091, Loss: 0.6130
Epoch: 092, Loss: 0.5361
Epoch: 093, Loss: 0.5594
Epoch: 094, Loss: 0.5049
Epoch: 095, Loss: 0.5043
Epoch: 096, Loss: 0.5235
Epoch: 097, Loss: 0.5451
Epoch: 098, Loss: 0.5329
Epoch: 099, Loss: 0.5008
Epoch: 100, Loss: 0.5350
Epoch: 101, Loss: 0.5343
Epoch: 102, Loss: 0.5138
Epoch: 103, Loss: 0.5377
Epoch: 104, Loss: 0.5353
Epoch: 105, Loss: 0.5176
Epoch: 106, Loss: 0.5229
Epoch: 107, Loss: 0.4558
Epoch: 108, Loss: 0.4883
Epoch: 109, Loss: 0.4659
Epoch: 110, Loss: 0.4908
Epoch: 111, Loss: 0.4966
Epoch: 112, Loss: 0.4725
Epoch: 113, Loss: 0.4787
Epoch: 114, Loss: 0.4390
Epoch: 115, Loss: 0.4199
Epoch: 116, Loss: 0.4810
Epoch: 117, Loss: 0.4484
Epoch: 118, Loss: 0.5080
Epoch: 119, Loss: 0.4241
Epoch: 120, Loss: 0.4745
Epoch: 121, Loss: 0.4651
Epoch: 122, Loss: 0.4652
Epoch: 123, Loss: 0.5580
Epoch: 124, Loss: 0.4861
Epoch: 125, Loss: 0.4405
Epoch: 126, Loss: 0.4292
Epoch: 127, Loss: 0.4409
Epoch: 128, Loss: 0.3575
Epoch: 129, Loss: 0.4468
Epoch: 130, Loss: 0.4603
Epoch: 131, Loss: 0.4108
Epoch: 132, Loss: 0.4601
Epoch: 133, Loss: 0.4258
Epoch: 134, Loss: 0.3852
Epoch: 135, Loss: 0.4028
Epoch: 136, Loss: 0.4245
Epoch: 137, Loss: 0.4300
Epoch: 138, Loss: 0.4693
Epoch: 139, Loss: 0.4314
Epoch: 140, Loss: 0.4031
Epoch: 141, Loss: 0.4290
Epoch: 142, Loss: 0.4110
Epoch: 143, Loss: 0.3863
Epoch: 144, Loss: 0.4215
Epoch: 145, Loss: 0.4519
Epoch: 146, Loss: 0.3940
Epoch: 147, Loss: 0.4429
Epoch: 148, Loss: 0.3527
Epoch: 149, Loss: 0.4390
Epoch: 150, Loss: 0.4212
Epoch: 151, Loss: 0.4128
Epoch: 152, Loss: 0.3779
Epoch: 153, Loss: 0.4801
Epoch: 154, Loss: 0.4130
Epoch: 155, Loss: 0.3962
Epoch: 156, Loss: 0.4262
Epoch: 157, Loss: 0.4210
Epoch: 158, Loss: 0.4081
Epoch: 159, Loss: 0.4066
Epoch: 160, Loss: 0.3782
Epoch: 161, Loss: 0.3836
Epoch: 162, Loss: 0.4172
Epoch: 163, Loss: 0.3993
Epoch: 164, Loss: 0.4477
Epoch: 165, Loss: 0.3714
Epoch: 166, Loss: 0.3610
Epoch: 167, Loss: 0.4546
Epoch: 168, Loss: 0.4387
Epoch: 169, Loss: 0.3793
Epoch: 170, Loss: 0.3704
Epoch: 171, Loss: 0.4286
Epoch: 172, Loss: 0.4131
Epoch: 173, Loss: 0.3795
Epoch: 174, Loss: 0.4230
Epoch: 175, Loss: 0.4139
Epoch: 176, Loss: 0.3586
Epoch: 177, Loss: 0.3588
Epoch: 178, Loss: 0.3911
Epoch: 179, Loss: 0.3810
Epoch: 180, Loss: 0.4203
Epoch: 181, Loss: 0.3583
Epoch: 182, Loss: 0.3690
Epoch: 183, Loss: 0.4025
Epoch: 184, Loss: 0.3920
Epoch: 185, Loss: 0.4369
Epoch: 186, Loss: 0.4317
Epoch: 187, Loss: 0.4911
Epoch: 188, Loss: 0.3369
Epoch: 189, Loss: 0.4945
Epoch: 190, Loss: 0.3912
Epoch: 191, Loss: 0.3824
Epoch: 192, Loss: 0.3479
Epoch: 193, Loss: 0.3798
Epoch: 194, Loss: 0.3799
Epoch: 195, Loss: 0.4015
Epoch: 196, Loss: 0.3615
Epoch: 197, Loss: 0.3985
Epoch: 198, Loss: 0.4664
Epoch: 199, Loss: 0.3714
Epoch: 200, Loss: 0.3810
def test():
    model.eval()
    out = model(data.x)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
    return test_acc

test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')
Test Accuracy: 0.5900
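
Although the original post does not show it here, the `visualize` helper defined earlier can be used to inspect the node representations the trained MLP produces; the following is a minimal usage sketch, assuming the `model` and `data` objects defined above:

model.eval()
out = model(data.x)           # MLP output for all nodes, shape [2708, 7]
visualize(out, color=data.y)  # t-SNE projection of the learned representations, colored by class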

2. Graph Convolutional Network (GCN)

Definition of GCN

GCN comes from the paper "Semi-Supervised Classification with Graph Convolutional Networks", and is mathematically defined as

$$\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},$$

where $\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}$ denotes the adjacency matrix with inserted self-loops (so that every node has an edge to itself), and $\hat{D}_{ii} = \sum_{j} \hat{A}_{ij}$ is its diagonal degree matrix (the diagonal entries are the node degrees, all other entries are zero). The adjacency matrix may contain values other than $1$; when its entries are not restricted to $\{0, 1\}$, they are interpreted as edge weights. $\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}$ is the symmetrically normalized adjacency matrix, and its node-wise formulation is

$$\mathbf{x}^{\prime}_i = \mathbf{\Theta} \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{e_{j,i}}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j,$$

where $\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}$, and $e_{j,i}$ denotes the edge weight from source node $j$ to target node $i$ (1.0 by default).
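
As a concrete illustration (a minimal sketch that is not part of the original post; the toy graph and dimensions are invented), the propagation rule $\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}$ can be computed directly with dense PyTorch tensors:

import torch

# Toy graph: 4 nodes on a path 0-1-2-3, 3-dimensional node features.
A = torch.tensor([[0., 1., 0., 0.],
                  [1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [0., 0., 1., 0.]])
X = torch.randn(4, 3)       # node feature matrix [N, C]
Theta = torch.randn(3, 2)   # learnable weight matrix [C, out_channels]

A_hat = A + torch.eye(4)                          # add self-loops: A_hat = A + I
D_inv_sqrt = torch.diag(A_hat.sum(1).pow(-0.5))   # D_hat^{-1/2}

X_prime = D_inv_sqrt @ A_hat @ D_inv_sqrt @ X @ Theta  # X' = D_hat^{-1/2} A_hat D_hat^{-1/2} X Theta
print(X_prime.shape)  # torch.Size([4, 2])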

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 13 22:39:17 2021
@author: Choi
"""


import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid  # PyG ships preprocessed datasets such as "Cora", "CiteSeer" and "PubMed"; they are loaded via the Planetoid class
import torch_geometric.nn as pyg_nn


# Step 1: prepare the data, load the Cora dataset
def get_data(folder="node_classify/cora", data_name="cora"):
    """
    :param folder: root directory in which the dataset is saved.
    :param data_name: name of the dataset.
    :return: a PyG dataset; its first element is the Data object described in the PyG docs, with attributes such as data.x and data.edge_index.
    """
    dataset = Planetoid(root=folder, name=data_name)
    return dataset

# Step 2: define the model, create the graph cnn model
class GraphCNN(nn.Module):
    def __init__(self, in_c, hid_c, out_c):
        super(GraphCNN, self).__init__()  # the subclass GraphCNN inherits all attributes and methods of the parent class nn.Module
        # GCNConv is the GCN layer introduced above; only input and output dimensions are needed. Two GCN layers are stacked here.
        self.conv1 = pyg_nn.GCNConv(in_channels=in_c, out_channels=hid_c)
        self.conv2 = pyg_nn.GCNConv(in_channels=hid_c, out_channels=out_c)

    def forward(self, data):
        # data.x  data.edge_index
        x = data.x  # [N, C], C is the feature dimension
        edge_index = data.edge_index  # [2, E], E is the number of edges
        hid = self.conv1(x=x, edge_index=edge_index)  # [N, D], N is the number of nodes, D is the hidden dimension of the first layer
        hid = F.relu(hid)

        out = self.conv2(x=hid, edge_index=edge_index)  # [N, out_c], out_c is the number of output classes, here 7

        out = F.log_softmax(out, dim=1)  # [N, out_c], the output

        return out


# todo list
class YouOwnGCN(nn.Module):  # placeholder: if you later want to use another graph convolution, implement it here and call it
    pass


def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # configure the GPU
    cora_dataset = get_data()

    # todo list
    # instantiate the network defined above
    my_net = GraphCNN(in_c=cora_dataset.num_node_features, hid_c=13, out_c=cora_dataset.num_classes)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")   # select the device

    my_net = my_net.to(device)  # move the model to the device
    data = cora_dataset[0].to(device)  # move the data (a single graph) to the device

    # Step 3: define the loss function and the optimizer
    optimizer = torch.optim.Adam(my_net.parameters(), lr=1e-3)  # optimizer

    # Step 4: training + testing
    # model train: train() keeps layers such as dropout/normalization in training mode, whereas eval() switches them to inference mode for testing
    my_net.train()
    for epoch in range(200):
        optimizer.zero_grad()  # clear the gradients every step, otherwise they accumulate

        output = my_net(data)  # forward pass / predictions

        loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])  # compute the loss on the training nodes only
        loss.backward()

        print("epoch:", epoch + 1, loss.item())
        optimizer.step()  # update the parameters

    # model test
    my_net.eval()
    _, prediction = my_net(data).max(dim=1)

    target = data.y

    test_correct = prediction[data.test_mask].eq(target[data.test_mask]).sum().item()
    test_number = data.test_mask.sum().item()

    print("Accuracy of Test Samples:{}%".format(100*test_correct/test_number))

if __name__ == "__main__":
    main()

epoch: 1 1.9453116655349731
epoch: 2 1.932291030883789
epoch: 3 1.919562816619873
epoch: 4 1.9069414138793945
epoch: 5 1.894279956817627
epoch: 6 1.8813810348510742
epoch: 7 1.8680614233016968
epoch: 8 1.854246735572815
epoch: 9 1.8398774862289429
epoch: 10 1.8251683712005615
epoch: 11 1.8101780414581299
epoch: 12 1.7949455976486206
epoch: 13 1.7795181274414062
epoch: 14 1.7638221979141235
epoch: 15 1.7478944063186646
epoch: 16 1.7317240238189697
epoch: 17 1.7153429985046387
epoch: 18 1.698785424232483
epoch: 19 1.6820697784423828
epoch: 20 1.6652289628982544
epoch: 21 1.6482977867126465
epoch: 22 1.6313109397888184
epoch: 23 1.6142857074737549
epoch: 24 1.597234845161438
epoch: 25 1.580147385597229
epoch: 26 1.5630677938461304
epoch: 27 1.5460320711135864
epoch: 28 1.5290590524673462
epoch: 29 1.512131929397583
epoch: 30 1.4952478408813477
epoch: 31 1.4784393310546875
epoch: 32 1.4616984128952026
epoch: 33 1.4450348615646362
epoch: 34 1.4284635782241821
epoch: 35 1.4119902849197388
epoch: 36 1.3956199884414673
epoch: 37 1.3793565034866333
epoch: 38 1.3631986379623413
epoch: 39 1.3471466302871704
epoch: 40 1.3312034606933594
epoch: 41 1.3153629302978516
epoch: 42 1.299630045890808
epoch: 43 1.284001350402832
epoch: 44 1.2684894800186157
epoch: 45 1.253075122833252
epoch: 46 1.237774133682251
epoch: 47 1.2225756645202637
epoch: 48 1.2074928283691406
epoch: 49 1.192516565322876
epoch: 50 1.1776549816131592
epoch: 51 1.162906527519226
epoch: 52 1.1482723951339722
epoch: 53 1.1337566375732422
epoch: 54 1.1193455457687378
epoch: 55 1.105042576789856
epoch: 56 1.0908522605895996
epoch: 57 1.0767767429351807
epoch: 58 1.0628098249435425
epoch: 59 1.0489519834518433
epoch: 60 1.0352014303207397
epoch: 61 1.021569848060608
epoch: 62 1.0080524682998657
epoch: 63 0.9946502447128296
epoch: 64 0.9813603162765503
epoch: 65 0.9681810736656189
epoch: 66 0.9551102519035339
epoch: 67 0.9421456456184387
epoch: 68 0.9292967915534973
epoch: 69 0.9165646433830261
epoch: 70 0.9039415717124939
epoch: 71 0.8914322257041931
epoch: 72 0.879046618938446
epoch: 73 0.8667768239974976
epoch: 74 0.854621946811676
epoch: 75 0.8425841927528381
epoch: 76 0.8306671380996704
epoch: 77 0.8188669085502625
epoch: 78 0.8071789145469666
epoch: 79 0.7956063151359558
epoch: 80 0.7841473817825317
epoch: 81 0.772806704044342
epoch: 82 0.7615860104560852
epoch: 83 0.7504868507385254
epoch: 84 0.7395078539848328
epoch: 85 0.7286520004272461
epoch: 86 0.7179200649261475
epoch: 87 0.7073159217834473
epoch: 88 0.6968382596969604
epoch: 89 0.6864873170852661
epoch: 90 0.6762606501579285
epoch: 91 0.6661574244499207
epoch: 92 0.6561772227287292
epoch: 93 0.6463234424591064
epoch: 94 0.6365894079208374
epoch: 95 0.6269803047180176
epoch: 96 0.6174972057342529
epoch: 97 0.608134388923645
epoch: 98 0.5988974571228027
epoch: 99 0.58978670835495
epoch: 100 0.5808026790618896
epoch: 101 0.571944534778595
epoch: 102 0.5632098317146301
epoch: 103 0.5545985698699951
epoch: 104 0.5461116433143616
epoch: 105 0.5377492308616638
epoch: 106 0.529509425163269
epoch: 107 0.5213926434516907
epoch: 108 0.5133975148200989
epoch: 109 0.5055218935012817
epoch: 110 0.4977670907974243
epoch: 111 0.4901328980922699
epoch: 112 0.4826175272464752
epoch: 113 0.4752180874347687
epoch: 114 0.4679345190525055
epoch: 115 0.4607686698436737
epoch: 116 0.4537140130996704
epoch: 117 0.44677218794822693
epoch: 118 0.4399445950984955
epoch: 119 0.43322959542274475
epoch: 120 0.4266222417354584
epoch: 121 0.420122891664505
epoch: 122 0.4137340784072876
epoch: 123 0.40745434165000916
epoch: 124 0.40128105878829956
epoch: 125 0.39520952105522156
epoch: 126 0.3892422914505005
epoch: 127 0.38337457180023193
epoch: 128 0.37760627269744873
epoch: 129 0.3719368577003479
epoch: 130 0.3663651943206787
epoch: 131 0.36088690161705017
epoch: 132 0.3555026650428772
epoch: 133 0.35021278262138367
epoch: 134 0.3450167179107666
epoch: 135 0.33991098403930664
epoch: 136 0.3348931670188904
epoch: 137 0.32996195554733276
epoch: 138 0.3251176178455353
epoch: 139 0.32035914063453674
epoch: 140 0.3156838119029999
epoch: 141 0.3110896646976471
epoch: 142 0.30657675862312317
epoch: 143 0.3021458089351654
epoch: 144 0.29779067635536194
epoch: 145 0.2935124635696411
epoch: 146 0.28931015729904175
epoch: 147 0.2851811945438385
epoch: 148 0.28112462162971497
epoch: 149 0.2771369516849518
epoch: 150 0.2732200622558594
epoch: 151 0.269372820854187
epoch: 152 0.2655932903289795
epoch: 153 0.26188188791275024
epoch: 154 0.2582346200942993
epoch: 155 0.2546558082103729
epoch: 156 0.2511393129825592
epoch: 157 0.24768483638763428
epoch: 158 0.24429070949554443
epoch: 159 0.24095328152179718
epoch: 160 0.23767392337322235
epoch: 161 0.23445293307304382
epoch: 162 0.23128703236579895
epoch: 163 0.2281753271818161
epoch: 164 0.22511635720729828
epoch: 165 0.2221105992794037
epoch: 166 0.21915863454341888
epoch: 167 0.21625804901123047
epoch: 168 0.21340808272361755
epoch: 169 0.21060511469841003
epoch: 170 0.20785026252269745
epoch: 171 0.20514263212680817
epoch: 172 0.2024814933538437
epoch: 173 0.19986717402935028
epoch: 174 0.19729600846767426
epoch: 175 0.19476723670959473
epoch: 176 0.1922818273305893
epoch: 177 0.18984107673168182
epoch: 178 0.18744073808193207
epoch: 179 0.18507839739322662
epoch: 180 0.1827564388513565
epoch: 181 0.18047228455543518
epoch: 182 0.17822620272636414
epoch: 183 0.17601743340492249
epoch: 184 0.1738460212945938
epoch: 185 0.1717095524072647
epoch: 186 0.16960883140563965
epoch: 187 0.16754303872585297
epoch: 188 0.16551196575164795
epoch: 189 0.1635134518146515
epoch: 190 0.1615469753742218
epoch: 191 0.15961284935474396
epoch: 192 0.1577097475528717
epoch: 193 0.15583734214305878
epoch: 194 0.15399456024169922
epoch: 195 0.15218158066272736
epoch: 196 0.15039794147014618
epoch: 197 0.14864230155944824
epoch: 198 0.14691464602947235
epoch: 199 0.1452144831418991
epoch: 200 0.14354094862937927
Accuracy of Test Samples:79.1%


3. Graph Attention Network (GAT)

Background

The graph convolutional network (GCN) has some limitations: it only applies to transductive tasks and cannot handle inductive settings such as dynamic graphs, and because its derivation rests on the graph Fourier transform, it struggles with directed graphs. A more flexible graph convolution is therefore needed.

Introducing the attention mechanism

Yoshua Bengio's team introduced masked self-attention on top of convolutional ideas and proposed the Graph Attention Network (GAT): every node in the graph can assign different weights to its neighbors according to their features, and no pre-built global graph structure is required.

Structure of GAT

The graph attention layer
  We first introduce a single graph attention layer. Its structure is illustrated in the diagram from the original paper (figure omitted: "GAT structure diagram").

Source paper:

GAT comes from the paper Graph Attention Networks, and is mathematically defined as

$$\mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},$$

where the attention coefficients $\alpha_{i,j}$ are computed as

$$\alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] \right)\right)}.$$
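
To make the attention computation concrete, here is a small sketch (not from the original post; the node features, dimensions, and variable names are invented) that evaluates $\alpha_{i,j}$ for a single node $i$ over its neighborhood with plain PyTorch:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
C, D = 8, 4                    # input / output feature dimensions
Theta = torch.randn(C, D)      # shared linear transformation Theta
a = torch.randn(2 * D)         # attention vector a

x_i = torch.randn(C)                            # feature of the target node i
x_nbrs = torch.randn(4, C)                      # features of its neighbors N(i)
x_all = torch.cat([x_i.unsqueeze(0), x_nbrs])   # N(i) ∪ {i}, node i in row 0

h_i = x_i @ Theta      # Theta x_i
h_j = x_all @ Theta    # Theta x_j for every j in N(i) ∪ {i}

# e_{i,j} = LeakyReLU(a^T [Theta x_i ‖ Theta x_j]), then softmax over the neighborhood
e = F.leaky_relu(torch.cat([h_i.expand_as(h_j), h_j], dim=1) @ a, negative_slope=0.2)
alpha = torch.softmax(e, dim=0)

x_i_new = (alpha.unsqueeze(1) * h_j).sum(dim=0)  # x'_i = sum_j alpha_{i,j} Theta x_j (includes alpha_{i,i} Theta x_i)
print(alpha.sum(), x_i_new.shape)  # tensor(1.), torch.Size([4])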

The GATConv module in PyG

The GATConv constructor signature:

"""
- `in_channels`: input feature dimension;
- `out_channels`: output feature dimension;
- `heads`: number of attention heads used by `GATConv` (number of multi-head attentions);
- `concat`: if `True`, the node representations produced by the different heads are concatenated (multiplying the representation dimension by `heads`); otherwise they are averaged;
"""
GATConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops: bool = True, bias: bool = True, **kwargs)

See the official GATConv documentation for details.
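
As a quick illustration of the effect of `heads` and `concat` (a sketch with an invented tiny graph and dimensions, not from the original post):

import torch
from torch_geometric.nn import GATConv

x = torch.randn(5, 16)                                   # 5 nodes with 16 features
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # 4 directed edges

conv_concat = GATConv(in_channels=16, out_channels=8, heads=4, concat=True)
conv_mean = GATConv(in_channels=16, out_channels=8, heads=4, concat=False)

print(conv_concat(x, edge_index).shape)  # torch.Size([5, 32]): 8 features x 4 heads, concatenated
print(conv_mean(x, edge_index).shape)    # torch.Size([5, 8]): the 4 heads are averaged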

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 13 22:39:17 2021
@author: Choi
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
import torch_geometric.nn as pyg_nn


# load Cora dataset
def get_data(folder="node_classify/cora", data_name="cora"):
    dataset = Planetoid(root=folder, name=data_name, pre_transform=T.KNNGraph(k=6), transform=T.NormalizeFeatures())
    return dataset

# create the graph cnn model
class GraphCNN(nn.Module):
    def __init__(self, in_c, hid_c, out_c):
        super(GraphCNN, self).__init__()  # the subclass GraphCNN inherits all attributes and methods of the parent class nn.Module
        self.conv1 = pyg_nn.GATConv(in_channels=in_c, out_channels=hid_c, dropout=0.6)
        self.conv2 = pyg_nn.GATConv(in_channels=hid_c, out_channels=out_c, dropout=0.6, heads=1, concat=True)

    def forward(self, data):
        # data.x  data.edge_index
        x = data.x  # [N, C], C is the feature dimension
        edge_index = data.edge_index  # [2, E], E is the number of edges
        x = F.dropout(x, p=0.6, training=self.training)
        hid = self.conv1(x=x, edge_index=edge_index)  # [N, D], N is the number of nodes, D is the hidden dimension of the first layer
        hid = F.relu(hid)

        out = self.conv2(x=hid, edge_index=edge_index)  # [N, out_c], out_c is the number of output classes, here 7

        out = F.log_softmax(out, dim=1)  # [N, out_c], the output

        return out

def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # configure the GPU
    cora_dataset = get_data()

    my_net = GraphCNN(in_c=cora_dataset.num_node_features, hid_c=8, out_c=cora_dataset.num_classes)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")   # select the device

    my_net = my_net.to(device)  # move the model to the device
    data = cora_dataset[0].to(device)  # move the data (a single graph) to the device

    optimizer = torch.optim.Adam(my_net.parameters(), lr=5e-3)  # optimizer

    # model train: train() keeps layers such as dropout/normalization in training mode, whereas eval() switches them to inference mode for testing
    my_net.train()
    for epoch in range(200):
        optimizer.zero_grad()  # clear the gradients every step, otherwise they accumulate

        output = my_net(data)  # forward pass / predictions

        loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])  # compute the loss on the training nodes only
        loss.backward()

        print("epoch", epoch + 1, loss.item())
        optimizer.step()  # update the parameters

    # model test
    my_net.eval()
    _, prediction = my_net(data).max(dim=1)

    target = data.y

    test_correct = prediction[data.test_mask].eq(target[data.test_mask]).sum().item()
    test_number = data.test_mask.sum().item()

    print("Accuracy of Test Samples:{}%".format(100*test_correct/test_number))

if __name__ == "__main__":
    main()

WARNING:root:The `pre_transform` argument differs from the one used in the pre-processed version of this dataset. If you really want to make use of another pre-processing technique, make sure to delete `node_classify\cora\cora\processed` first.
epoch 1 1.9466304779052734
epoch 2 1.9452993869781494
epoch 3 1.9455339908599854
epoch 4 1.94290030002594
epoch 5 1.9421229362487793
epoch 6 1.9380170106887817
epoch 7 1.9406156539916992
epoch 8 1.9380996227264404
epoch 9 1.9345784187316895
epoch 10 1.9321197271347046
epoch 11 1.9269548654556274
epoch 12 1.9265786409378052
epoch 13 1.9235918521881104
epoch 14 1.9232169389724731
epoch 15 1.9175101518630981
epoch 16 1.9152055978775024
epoch 17 1.904789924621582
epoch 18 1.9089778661727905
epoch 19 1.9022480249404907
epoch 20 1.9059253931045532
epoch 21 1.901972770690918
epoch 22 1.8945059776306152
epoch 23 1.8917443752288818
epoch 24 1.8988347053527832
epoch 25 1.8938117027282715
epoch 26 1.8805230855941772
epoch 27 1.8783419132232666
epoch 28 1.8790854215621948
epoch 29 1.8758597373962402
epoch 30 1.8523036241531372
epoch 31 1.871390700340271
epoch 32 1.840027928352356
epoch 33 1.8444995880126953
epoch 34 1.8628698587417603
epoch 35 1.8366267681121826
epoch 36 1.8411531448364258
epoch 37 1.8486825227737427
epoch 38 1.8458245992660522
epoch 39 1.837032675743103
epoch 40 1.820558786392212
epoch 41 1.8241647481918335
epoch 42 1.7830429077148438
epoch 43 1.7953040599822998
epoch 44 1.8059921264648438
epoch 45 1.814132809638977
epoch 46 1.7854561805725098
epoch 47 1.8005499839782715
epoch 48 1.7478634119033813
epoch 49 1.7778631448745728
epoch 50 1.7594667673110962
epoch 51 1.7636114358901978
epoch 52 1.7309726476669312
epoch 53 1.7277470827102661
epoch 54 1.7192648649215698
epoch 55 1.7023435831069946
epoch 56 1.729498267173767
epoch 57 1.7269208431243896
epoch 58 1.7367229461669922
epoch 59 1.692749261856079
epoch 60 1.7086790800094604
epoch 61 1.6986669301986694
epoch 62 1.6935926675796509
epoch 63 1.7159898281097412
epoch 64 1.664335012435913
epoch 65 1.6296015977859497
epoch 66 1.608216643333435
epoch 67 1.6530961990356445
epoch 68 1.6244709491729736
epoch 69 1.624013900756836
epoch 70 1.5743077993392944
epoch 71 1.6142922639846802
epoch 72 1.5756471157073975
epoch 73 1.5973869562149048
epoch 74 1.5797481536865234
epoch 75 1.5279662609100342
epoch 76 1.586087942123413
epoch 77 1.5813978910446167
epoch 78 1.5842024087905884
epoch 79 1.5828092098236084
epoch 80 1.474631667137146
epoch 81 1.5950703620910645
epoch 82 1.5598087310791016
epoch 83 1.5500640869140625
epoch 84 1.507140040397644
epoch 85 1.5781816244125366
epoch 86 1.4175552129745483
epoch 87 1.464072823524475
epoch 88 1.5317105054855347
epoch 89 1.3428765535354614
epoch 90 1.3941960334777832
epoch 91 1.518794298171997
epoch 92 1.4353892803192139
epoch 93 1.381829023361206
epoch 94 1.4501067399978638
epoch 95 1.453406810760498
epoch 96 1.4458706378936768
epoch 97 1.3227503299713135
epoch 98 1.4084933996200562
epoch 99 1.4429619312286377
epoch 100 1.4185216426849365
epoch 101 1.4342074394226074
epoch 102 1.3611711263656616
epoch 103 1.3482654094696045
epoch 104 1.3560689687728882
epoch 105 1.4058716297149658
epoch 106 1.3985254764556885
epoch 107 1.5533326864242554
epoch 108 1.4676203727722168
epoch 109 1.3352069854736328
epoch 110 1.3208035230636597
epoch 111 1.448912501335144
epoch 112 1.2946215867996216
epoch 113 1.270990014076233
epoch 114 1.4824904203414917
epoch 115 1.3864918947219849
epoch 116 1.3899513483047485
epoch 117 1.3436856269836426
epoch 118 1.2464419603347778
epoch 119 1.2583202123641968
epoch 120 1.3068723678588867
epoch 121 1.3028730154037476
epoch 122 1.2661311626434326
epoch 123 1.233555793762207
epoch 124 1.3794339895248413
epoch 125 1.3032712936401367
epoch 126 1.4174433946609497
epoch 127 1.2497429847717285
epoch 128 1.3369816541671753
epoch 129 1.257314682006836
epoch 130 1.297139048576355
epoch 131 1.265181541442871
epoch 132 1.3884270191192627
epoch 133 1.2926338911056519
epoch 134 1.2739715576171875
epoch 135 1.199686050415039
epoch 136 1.2682065963745117
epoch 137 1.1905218362808228
epoch 138 1.3132424354553223
epoch 139 1.230311393737793
epoch 140 1.2008371353149414
epoch 141 1.190712571144104
epoch 142 1.159222960472107
epoch 143 1.1686118841171265
epoch 144 1.3327347040176392
epoch 145 1.3072259426116943
epoch 146 1.1695005893707275
epoch 147 1.1509381532669067
epoch 148 1.2430397272109985
epoch 149 1.2273073196411133
epoch 150 1.244320034980774
epoch 151 1.120166540145874
epoch 152 1.1097781658172607
epoch 153 1.2332597970962524
epoch 154 1.169371247291565
epoch 155 1.1736806631088257
epoch 156 1.13922917842865
epoch 157 1.1547086238861084
epoch 158 1.2604109048843384
epoch 159 1.2259297370910645
epoch 160 1.1589984893798828
epoch 161 1.170103907585144
epoch 162 1.2283793687820435
epoch 163 1.2285842895507812
epoch 164 1.301701307296753
epoch 165 1.0803091526031494
epoch 166 1.1776891946792603
epoch 167 1.2095720767974854
epoch 168 1.2360646724700928
epoch 169 1.1875016689300537
epoch 170 1.0908942222595215
epoch 171 1.129482626914978
epoch 172 1.1492669582366943
epoch 173 1.168258786201477
epoch 174 1.158072829246521
epoch 175 1.1558555364608765
epoch 176 1.1341286897659302
epoch 177 1.084506869316101
epoch 178 1.0895988941192627
epoch 179 1.0885933637619019
epoch 180 1.1820029020309448
epoch 181 1.2151626348495483
epoch 182 1.1370195150375366
epoch 183 1.101808786392212
epoch 184 1.0804322957992554
epoch 185 1.04401695728302
epoch 186 1.1930063962936401
epoch 187 1.2779303789138794
epoch 188 1.0886690616607666
epoch 189 1.13993239402771
epoch 190 0.9829103946685791
epoch 191 1.046717882156372
epoch 192 1.1881943941116333
epoch 193 1.0009737014770508
epoch 194 1.1312161684036255
epoch 195 1.1127839088439941
epoch 196 1.114255428314209
epoch 197 1.0537904500961304
epoch 198 1.0096962451934814
epoch 199 1.0265328884124756
epoch 200 1.1225985288619995
Accuracy of Test Samples:79.8%


Comparison of GCN, GAT, and MLP

Comparing GCN and MLP

GCNs are similar to MLPs: both learn a feature vector $X_i$ for each node through a multi-layer network and then feed this learned feature vector into a linear classifier for the classification task. A $k$-layer GCN is identical to a $k$-layer MLP applied to the feature vector $X_i$ of every node in the graph, except that at the input to each layer, each node's hidden representation is averaged over its neighborhood.

In each graph convolution layer, the node representations are updated using three steps:

  1. feature propagation
  2. linear transformation
  3. pointwise nonlinear activation

Feature propagation

Feature propagation is what distinguishes a GCN from an MLP, because the input to every layer is an average over each node's local neighborhood:

$$\bar{\mathbf{h}}^{(k)}_v \leftarrow \frac{1}{d_v + 1}\,\mathbf{h}^{(k-1)}_v + \sum_{u \in \mathcal{N}(v)} \frac{1}{\sqrt{(d_v + 1)(d_u + 1)}}\,\mathbf{h}^{(k-1)}_u. \tag{2}$$

The update in equation (2) can be expressed as a simple matrix operation using

$$\mathbf{S} = \tilde{\mathbf{D}}^{-1/2}\,\tilde{\mathbf{A}}\,\tilde{\mathbf{D}}^{-1/2}, \tag{3}$$

where $\mathbf{S}$ is the "normalized" adjacency matrix with added self-loops (it is not actually normalized), $\tilde{\mathbf{A}} = \mathbf{A} + \mathbf{I}$ is the adjacency matrix with self-loops, and $\tilde{\mathbf{D}}$ is the degree matrix of $\tilde{\mathbf{A}}$. Applying equation (2) to all nodes simultaneously then reduces to a single sparse matrix multiplication:

$$\bar{\mathbf{H}}^{(k)} \leftarrow \mathbf{S}\,\mathbf{H}^{(k-1)}. \tag{4}$$

This step smooths the hidden representations locally along the edges of the graph and ultimately encourages similar predictions among locally connected nodes.

Feature transformation and nonlinear transition

After this local smoothing, a GCN layer is equivalent to a standard MLP layer. Each layer has its own learnable weight matrix, so the smoothed hidden representations are transformed linearly (right-multiplying by a parameter matrix is a linear map). Finally, a pointwise nonlinear activation such as ReLU is applied to obtain the output representations:

$$\mathbf{H}^{(k)} \leftarrow \mathrm{ReLU}\!\left(\bar{\mathbf{H}}^{(k)}\,\mathbf{\Theta}^{(k)}\right). \tag{5}$$

Classifier

For node classification, the last layer is analogous to that of an MLP: a softmax classifier predicts the node labels. The class predictions of a $K$-layer GCN for all nodes can thus be written as

$$\hat{\mathbf{Y}}_{\mathrm{GCN}} = \mathrm{softmax}\!\left(\mathbf{S}\,\mathbf{H}^{(K-1)}\,\mathbf{\Theta}^{(K)}\right). \tag{6}$$
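
Putting equations (2)-(6) together, the following sketch (not from the original post; the toy graph and dimensions are invented) runs a two-layer GCN forward pass as "propagate, transform, activate", followed by the softmax classifier:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy graph: 4 nodes on a path 0-1-2-3, 5-dimensional features, 3 classes.
A = torch.tensor([[0., 1., 0., 0.],
                  [1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [0., 0., 1., 0.]])
X = torch.randn(4, 5)

A_tilde = A + torch.eye(4)                         # A_tilde = A + I
D_inv_sqrt = torch.diag(A_tilde.sum(1).pow(-0.5))
S = D_inv_sqrt @ A_tilde @ D_inv_sqrt              # eq. (3): S = D_tilde^{-1/2} A_tilde D_tilde^{-1/2}

Theta1 = torch.randn(5, 8)   # layer-1 weights
Theta2 = torch.randn(8, 3)   # layer-2 (classifier) weights

H1 = F.relu(S @ X @ Theta1)                # eqs. (4)+(5): propagate, then transform and apply ReLU
Y_hat = F.softmax(S @ H1 @ Theta2, dim=1)  # eq. (6): final propagation + softmax classifier
print(Y_hat.shape, Y_hat.sum(dim=1))       # [4, 3]; each row sums to 1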
