基于Pytorch的图卷积网络GCN实例应用及详解
一、图卷积网络GCN定义
图卷积网络GCN实际上就是特征提取器,只不过GCN的数据对象是图。图的结构一般来说是十分不规则,可以看作是多维的一种数据。GCN精妙地设计了一种从图数据中提取特征的方法,从而让我们可以使用这些特征去对图数据进行节点分类(node classification)、图分类(graph classification)、边预测(link prediction)和获得图的嵌入表示(graph embedding),用途十分广泛。
二、图卷积网络GCN的原理
- 博主学习图卷积网络主要参考下面两篇深入浅出的好文章:
- 第一篇参考文章:点击打开《一文读懂图卷积GCN》文章
- 第二篇参考文章:点击打开《最通俗易懂的图神经网络(GCN)原理详解》文章
- 阅读上面两篇文章需要理解图的定义、图相关矩阵的定义(邻接矩阵、度矩阵、拉普拉斯矩阵、稀疏矩阵COO)、图卷积的通式或者公式的推导发展及意义。
- 若阅读完两篇文章公式推导大家还对 邻接矩阵的归一化操作,通过对邻接矩阵两边乘以节点的度开方然后取逆得到 这个知识点“一知半解”,那么请看下面博主就图卷积网络GCN公式进行举例计算,以此帮助有需要的小伙伴理解,理解的可以选择跳过。
归一化操作目标:对称且归一化的矩阵简单来说就是让矩阵的每一行都相加为1。
三、图卷积网络GCN实现前期准备
PyTorch Geometric (简称PYG)中设计了一种新的表示图数据的存储结构,也是 PyTorch Geometric中实现的各种方法的基本数据形式。GCN在PyTorch Geometric中有已经封装好的模型(当然大家也可以自己用python代码根据GCN的实现原理自己搭建模型,那么可以不使用PYG自带模型),因此可以直接导包再根据自己的数据集或者PyTorch Geometric自带的数据集(如Cora、ENZYMES等)去实现节点分类(node classification)、图分类(graph classification)、边预测(link prediction)和获得图的嵌入表示(graph embedding)等这些案例。
- 下载编辑器和配置程序运行环境:点击打开《基于Windows中学习Deep Learning之搭建Anaconda+Pytorch(Cuda+Cudnn)+Pycharm工具和配置环境完整最简版》文章
- 下载PyTorch Geometric包: 点击打开《基于Pytorch中安装torch_geometric简单详细完整版》文章
- PyTorch Geometric包使用的官方文档:点击打开PyTorch Geometric使用介绍官方网页
四、图卷积网络GCN实现案例分析
首先PYG自带的数据集网上的资料和代码很多,大家第一次练手博主认为可以选择PYG自带的数据集,如Cora等,并且训练预测的结果也是非常不错的,大家理解代码也是极好的,给用户体验感受非常不错,因此博主强烈的推荐一篇文章大家可以去试试:点击打开《[PyG] 1.如何使用GCN完成一个最基本的训练过程(含GCN实现)》文章 。但是另一种情况是用户需要用自己的数据集(比如mat文件)通过图卷积网络GCN去实现一些图预测等目的,所以博主通过大量阅读理解和总结,提供一个已经实现的用自己的数据集去跑GCN模型以实现图预测的案例给大家做个参考。
- 案例目的是构造图卷积网络模型训练后进行图片二分类(0和1)预测。
- mat文件数据集(MATLAB的专属文件)的导入和构造,博主已有的数据集存放在J盘以aidb.mat文件形式保存下来。
- aidb.mat文件数据结构如下图。aidb 是一个数据结构体(struct),包含10364个子结构体(struct),每个子结构体(struct)表示一张图,每张图又包含 nl、am 和 no;nl 表示节点及特征,维度是[20,1];am 表示节点之间的邻接矩阵,维度是[20,20];no 表示图片序号,维度是[1] 。laidb 表示所有图片的分类,结果是0或1,总共有10364张图,维度是[10364] ,注意:每张图的维度是[1] 。
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import mat4py
import scipy.sparse as sp
from torch_geometric.data import Data
import torch_geometric.nn as pyg_nn
from torch_geometric.data import DataLoader
import warnings
warnings.filterwarnings("ignore", category=Warning)
# 数据集读取
data = mat4py.loadmat('J:/aidb.mat') # 读取数据集mat文件数据
laidb = data['laidb'] # 图标签(0和1),输出为list数据类型
aidb = data['aidb'] # 图数据,list数据类型
nl = aidb['nl'] # 节点及节点特征,[20,1]
am = aidb['am'] # 节点和边的邻接矩阵[20,20]
# no = aidb['no'] # 图的编号,也就是第几张图
dataset = [] # data数据对象的list集合
for i in range(len(laidb)):
# 数据转换
# 邻接矩阵转换成COO稀疏矩阵及转换
# am = np.array(am[i]) # 无所谓,list先转换成numpy
edge_index_temp = sp.coo_matrix(am[i])
indices = np.vstack((edge_index_temp.row, edge_index_temp.col))
edge_index = torch.LongTensor(indices)
# 节点及节点特征数据转换[20,1]
x =