PyG安装及入门(一)

很早以前就想研究一下怎么用PyG,现在终于有一点时间了,开更!   、

一、安装

安装最麻烦的是对齐所有东西的版本,尤其是安装
这里涉及到的有 python-pytorch-cuda-cudnn-PyG相关框架

1.1 创建虚拟环境,安装pytorch

官方安装手册Start Locally | PyTorch

当中有提到python版本3.8-3.11现在基本上支持大多数的pytorch 版本了

Python 3.8-3.11 is generally installed by default on any of our supported Linux distributions, which meets our recommendation.

    所以直接用Python3.9创建虚拟环境
 

# 创建虚拟环境
conda create -n pyg python=3.9

# 进入虚拟环境
conda activate pyg

# 查看所有虚拟环境
conda env list

虚拟环境创建好以后,查看cuda版本,在官网上选择出对应的安装指令

但是官方指令常常会有安装缓慢的时候,具体用pip还是conda可以换着试试,换源也是常用的方法

1.2 安装PyG

官网安装手册Installation — pytorch_geometric documentation
这里我选择的是安装stable版本,最好是把相关的依赖一起安装上

二、导入PyG读取数据集

2.1 Cora数据集介绍

PyG内置了几种常用的数据集,这里主要用到Cora数据集,解决简单的分类问题
Cora数据集是一个机器学习论文数据集,统计了2078篇文章,内含.content、.cites两个文件
.content:通过统计机器学习不同领域中的key words在每篇论文中的出现,给出论文所属的分类,共有7个label
.cites: 记录了论文之间的引用关系,比如"paper1:paper2"代表有向图中的链路"paper2->paper1”,应该是paper2引用了paper1

TODO: 构造一个分类器,利用.cites中的论文引用关系判断论文所属分类,以.content中给出的label为基准,通过可视化、构造混淆矩阵等方式评判分类器的性能
Cora数据集虽然只有一张图,但是充分使用节点之间的连接关系构造节点特征,是图神经网络入门必不可少的数据集之一

2.2 PyG初步读取数据集

PyG的基础教程中给了一段读取Cora数据集的代码,参考Introduction by Example — pytorch_geometric documentation
 

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')
>>> Cora()

len(dataset)
>>> 1

dataset.num_classes
>>> 7

dataset.num_node_features
>>> 1433

data = dataset[0]
>>> Data(edge_index=[2, 10556], test_mask=[2708],
         train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])

data.is_undirected()
>>> True

data.train_mask.sum().item()
>>> 140

data.val_mask.sum().item()
>>> 500

data.test_mask.sum().item()
>>> 1000

这里的data为每个节点分配了label,并有额外的node-level属性:
1.train_mask: denotes against which nodes to train (140 nodes)
2.val_mask: denotes which nodes to use for validation, e.g., to perform early stopping (500 nodes)
3.test_mask: denotes against which nodes to test (1000 nodes).

2.3 初步可视化

这里参考Pytorch Geometric 系列教程1:互动可视化Graph数据集 - MyEncyclopedia

将 cora 转换成 networkx 格式,Cora 有 7 种节点类型,将每种节点类型赋予不同颜色,调用 networkx 的 spring_layout 计算每个节点的弹簧布局下的位置
完整代码如下:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

name_data = 'Cora'
dataset = Planetoid(root='./data/', name=name_data)

from torch_geometric.utils import to_networkx
cora = to_networkx(dataset.data)

print(cora.is_directed())

node_classes = dataset.data.y.data.numpy()
print(node_classes)

node_color = ["red","blue","green","yellow","peru","violet","cyan"] 
node_label = np.array(list(cora.nodes)) 

import matplotlib.pyplot as plt
import networkx as nx

pos = nx.layout.spring_layout(cora)

plt.figure(figsize=(16,12))

for i in np.arange(len(np.unique(node_classes))):
    node_list = node_label[node_classes == i]
    nx.draw_networkx_nodes(cora, pos, nodelist=list(node_list),
                           node_size=50,  
                           node_color=node_color[i], 
                           alpha=0.8)

nx.draw_networkx_edges(cora, pos,width=1,edge_color="black")
plt.show()

  • 35
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值