很早以前就想研究一下怎么用PyG,现在终于有一点时间了,开更! 、
一、安装
安装最麻烦的是对齐所有东西的版本,尤其是安装
这里涉及到的有 python-pytorch-cuda-cudnn-PyG相关框架
1.1 创建虚拟环境,安装pytorch
官方安装手册Start Locally | PyTorch
当中有提到python版本3.8-3.11现在基本上支持大多数的pytorch 版本了
Python 3.8-3.11 is generally installed by default on any of our supported Linux distributions, which meets our recommendation.
所以直接用Python3.9创建虚拟环境
# 创建虚拟环境
conda create -n pyg python=3.9
# 进入虚拟环境
conda activate pyg
# 查看所有虚拟环境
conda env list
虚拟环境创建好以后,查看cuda版本,在官网上选择出对应的安装指令
但是官方指令常常会有安装缓慢的时候,具体用pip还是conda可以换着试试,换源也是常用的方法
1.2 安装PyG
官网安装手册Installation — pytorch_geometric documentation
这里我选择的是安装stable版本,最好是把相关的依赖一起安装上
二、导入PyG读取数据集
2.1 Cora数据集介绍
PyG内置了几种常用的数据集,这里主要用到Cora数据集,解决简单的分类问题
Cora数据集是一个机器学习论文数据集,统计了2078篇文章,内含.content、.cites两个文件
.content:通过统计机器学习不同领域中的key words在每篇论文中的出现,给出论文所属的分类,共有7个label
.cites: 记录了论文之间的引用关系,比如"paper1:paper2"代表有向图中的链路"paper2->paper1”,应该是paper2引用了paper1
TODO: 构造一个分类器,利用.cites中的论文引用关系判断论文所属分类,以.content中给出的label为基准,通过可视化、构造混淆矩阵等方式评判分类器的性能
Cora数据集虽然只有一张图,但是充分使用节点之间的连接关系构造节点特征,是图神经网络入门必不可少的数据集之一
2.2 PyG初步读取数据集
PyG的基础教程中给了一段读取Cora数据集的代码,参考Introduction by Example — pytorch_geometric documentation
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
>>> Cora()
len(dataset)
>>> 1
dataset.num_classes
>>> 7
dataset.num_node_features
>>> 1433
data = dataset[0]
>>> Data(edge_index=[2, 10556], test_mask=[2708],
train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
data.is_undirected()
>>> True
data.train_mask.sum().item()
>>> 140
data.val_mask.sum().item()
>>> 500
data.test_mask.sum().item()
>>> 1000
这里的data为每个节点分配了label,并有额外的node-level属性:
1.train_mask
: denotes against which nodes to train (140 nodes)
2.val_mask
: denotes which nodes to use for validation, e.g., to perform early stopping (500 nodes)
3.test_mask
: denotes against which nodes to test (1000 nodes).
2.3 初步可视化
这里参考Pytorch Geometric 系列教程1:互动可视化Graph数据集 - MyEncyclopedia
将 cora 转换成 networkx 格式,Cora 有 7 种节点类型,将每种节点类型赋予不同颜色,调用 networkx 的 spring_layout 计算每个节点的弹簧布局下的位置
完整代码如下:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
name_data = 'Cora'
dataset = Planetoid(root='./data/', name=name_data)
from torch_geometric.utils import to_networkx
cora = to_networkx(dataset.data)
print(cora.is_directed())
node_classes = dataset.data.y.data.numpy()
print(node_classes)
node_color = ["red","blue","green","yellow","peru","violet","cyan"]
node_label = np.array(list(cora.nodes))
import matplotlib.pyplot as plt
import networkx as nx
pos = nx.layout.spring_layout(cora)
plt.figure(figsize=(16,12))
for i in np.arange(len(np.unique(node_classes))):
node_list = node_label[node_classes == i]
nx.draw_networkx_nodes(cora, pos, nodelist=list(node_list),
node_size=50,
node_color=node_color[i],
alpha=0.8)
nx.draw_networkx_edges(cora, pos,width=1,edge_color="black")
plt.show()