一、环境配置
因为没有GPU,我这里只安装了CPU版本
- 先安装pytorch
pip install torch torchvision torchaudio
- 安装PyG
pip install torch-scatter
pip install torch-sparse
pip install torch-cluster
pip install torch-spline-conv
pip install torch-geometric
二、Data类
Data类的构造函数如下:
class Data:
    """Plain container for one graph (excerpt of PyG's ``Data`` constructor)."""

    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, **kwargs):
        r"""Store the graph tensors plus any extra keyword attributes.

        Args:
            x (Tensor, optional): Node feature matrix of shape
                `[num_nodes, num_node_features]`.
            edge_index (LongTensor, optional): Edge index matrix of shape
                `[2, num_edges]`. Row 0 holds the tail node and row 1 the
                head node of each edge.
                NOTE(review): the original note said "head points to tail",
                which looks reversed (edges are presumably directed from
                row 0 to row 1) -- confirm against the PyG documentation.
            edge_attr (Tensor, optional): Edge feature matrix of shape
                `[num_edges, num_edge_features]`.
            y (Tensor, optional): Labels for the nodes or for the whole
                graph, of arbitrary shape (edge labels are possible too).
            **kwargs: Extra attributes. `num_nodes` is stored separately;
                every other key is attached through item assignment.
        """
        self.x, self.edge_index = x, edge_index
        self.edge_attr, self.y = edge_attr, y

        for name, value in kwargs.items():
            if name != 'num_nodes':
                # Item assignment relies on a __setitem__ that this excerpt
                # does not define.
                self[name] = value
                continue
            # An explicitly supplied node count gets its own attribute.
            self.__num_nodes__ = value
Data类的使用:
import torch
from torch_geometric.data import Data
# Node feature matrix: 4 nodes with 2 features each.
x = torch.tensor([[2, 1], [5, 6], [3, 7], [12, 0]], dtype=torch.float)
# Node class labels, one per node. Stored as int64 (torch.long) because
# PyTorch classification losses such as F.nll_loss reject int32 targets.
y = torch.tensor([0, 1, 0, 1], dtype=torch.long)
# Adjacency in COO format: column k lists the two endpoints of edge k
# (5 edges in total).
edge_index = torch.tensor([[0, 0, 1, 2, 3],
                           [3, 1, 0, 1, 2]], dtype=torch.long)
data = Data(x=x, y=y, edge_index=edge_index)
三、Dataset类
生成数据集
from torch_geometric.datasets import Planetoid
# Build the 'Cora' dataset through the Planetoid loader.
# NOTE(review): root is an absolute path directly under the filesystem
# root; presumably dataset files are stored there -- confirm it is writable.
dataset = Planetoid(
    name='Cora',
    root='/dataset/Cora',
)
数据集的使用
# `Net`, `device` and `F` are not defined in the visible text and must
# already be in scope (F is presumably torch.nn.functional -- confirm).
model = Net().to(device)
# Train on the first graph of the dataset, placed on the same device.
data = dataset[0].to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
    weight_decay=5e-4,
)
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    # Only the entries selected by train_mask contribute to the loss.
    # NOTE(review): nll_loss expects log-probabilities, so Net presumably
    # ends with log_softmax -- confirm against the Net definition.
    train_pred = out[data.train_mask]
    train_target = data.y[data.train_mask]
    loss = F.nll_loss(train_pred, train_target)
    loss.backward()
    optimizer.step()
作业
请通过继承Data类实现一个类,专门用于表示“机构-作者-论文”的网络。该网络包含“机构”、“作者”和“论文”三类节点,以及“作者-机构”和“作者-论文”两类边。对要实现的类的要求:1)用不同的属性存储不同节点的属性;2)用不同的属性存储不同的边(边没有属性);3)逐一实现获取不同节点数量的方法。
from torch_geometric.data import Data
class MyData(Data):
    """Graph container for an institution-author-paper network.

    Each node type keeps its attributes in its own field (`author`,
    `institution`, `paper`), and each edge set is kept in its own field
    (`author_edge`, `institution_edge`, `paper_edge`). Edges carry no
    attributes.

    NOTE(review): the assignment describes two edge types (author-institution
    and author-paper) while this class stores three edge fields; how they map
    onto each other is not shown here -- confirm with the author.
    """

    def __init__(self, author=None, author_edge=None, institution=None,
                 institution_edge=None, paper=None, paper_edge=None, **kwargs):
        """Store the per-type node attributes and edge sets.

        Every argument defaults to None so the class can also be created
        without arguments.
        NOTE(review): PyG utilities presumably re-create Data subclasses
        with no arguments in some code paths (see the PyG mini-batching
        docs) -- confirm for the installed version.

        Args:
            author: Author node attributes; must support `len()`.
            author_edge: Edges attached to author nodes, stored as given.
            institution: Institution node attributes; must support `len()`.
            institution_edge: Edges attached to institution nodes, stored
                as given.
            paper: Paper node attributes; must support `len()`.
            paper_edge: Edges attached to paper nodes, stored as given.
            **kwargs: Forwarded unchanged to `Data.__init__`.
        """
        super().__init__(**kwargs)
        self.author = author
        self.author_edge = author_edge
        self.institution = institution
        self.institution_edge = institution_edge
        self.paper = paper
        self.paper_edge = paper_edge

    def _count(self, name):
        """Return `len()` of attribute `name`, or 0 if it is unset or None."""
        # getattr with a default tolerates attributes that were never
        # stored as well as ones explicitly set to None.
        nodes = getattr(self, name, None)
        return 0 if nodes is None else len(nodes)

    def get_author_num(self):
        """Return the number of author nodes (0 if none are stored)."""
        return self._count('author')

    def get_institution_num(self):
        """Return the number of institution nodes (0 if none are stored)."""
        return self._count('institution')

    def get_paper_num(self):
        """Return the number of paper nodes (0 if none are stored)."""
        return self._count('paper')