1、环境配置
colab 是google集成的深度学习开发环境,tensorflow、pytorch等软件已经配置。在刚开始配置时,使用了默认的torch版本和cuda版本,导致安装过程异常缓慢,需要一个多小时。经过Datawhale中同学的帮助,经过多次实验,应该是cuda版本的问题。
在colab中cuda版本是cu10.1,如果配置为cu11.1,则可以正常安装,安装过程如下:
!pip3 install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
!pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
!pip install torch-geometric
2、data类
Data类是PyG中图操作的基础类,可以用来调用torch_geometric中的内置数据集,也可以根据需要构造问题需要的数据集。在data类基础上构造类见第三部分。
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/dataset/Cora', name='Cora')
# Cora()
print(len(dataset))
# 1
print(dataset.num_classes)
# 7
print(dataset.num_node_features)
# 1433
3、请通过继承Data
类实现一个类,专门用于表示“机构-作者-论文”的网络。该网络包含“机构“、”作者“和”论文”三类节点,以及“作者-机构“和“作者-论文“两类边。对要实现的类的要求:1)用不同的属性存储不同节点的属性;2)用不同的属性存储不同的边(边没有属性);3)逐一实现获取不同节点数量的方法。
from torch_geometric.data.data import Data
# 请通过继承Data 类实现⼀个类,专⻔⽤于表示“机构-作者-论⽂”的⽹络。该⽹络包含“机构“、”作者“和”论⽂”
# 三类节点,以及“作者-机构“和“作者-论⽂“两类边。对要实现的类的要求:1)⽤不同的属性存储不同节点的属
# 性;2)⽤不同的属性存储不同的边(边没有属性);3)逐⼀实现获取不同节点数量的⽅法。
class Network(Data):
def __init__(self, org=None,author=None,article=None, org_author=None,author_article=None,**kwargs):
r"""
Args:
org (Tensor, optional): 节点属性矩阵,⼤⼩为`[num_nodes, num_node_features]`
author (Tensor, optional): 节点属性矩阵,⼤⼩为`[num_nodes, num_node_features]`
article (Tensor, optional): 节点属性矩阵,⼤⼩为`[num_nodes, num_node_features]`
org_author (LongTensor, optional): 边索引矩阵,⼤⼩为`[2, num_edges]`,第0⾏为尾节点,第1⾏为头节点,头指向尾
author_article (LongTensor, optional): 边索引矩阵,⼤⼩为`[2, num_edges]`,第0⾏为尾节点,第1⾏为头节点,头指向尾
"""
super().__init__(**kwargs)
self.org = org
self.author = author
self.article = article
self.org_author = org_author
self.author_article = author_article
for key, item in kwargs.items():
if key == 'num_nodes':
self.__num_nodes__ = item
else:
self[key] = item
@property
def num_org_nodes(self):
if self.org is None:
return 0
else:
return self.org.shape[0]
@property
def num_author_nodes(self):
if self.author is None:
return 0
else:
return self.author.shape[0]
@property
def num_article_nodes(self):
if self.article is None:
return 0
else:
return self.article.shape[0]
单位的数量: 2708
作者数量: 2708
文章数量 2708