通过构造函数
Data
类的构造函数:
class Data(object):
def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, **kwargs):
r"""
Args:
x (Tensor, optional): 节点属性矩阵,大小为`[num_nodes, num_node_features]`
edge_index (LongTensor, optional): 边索引矩阵,大小为`[2, num_edges]`,第0行为尾节点,第1行为头节点,头指向尾
edge_attr (Tensor, optional): 边属性矩阵,大小为`[num_edges, num_edge_features]`
y (Tensor, optional): 节点或图的标签,任意大小(,其实也可以是边的标签)
"""
self.x = x
self.edge_index = edge_index
self.edge_attr = edge_attr
self.y = y
for key, item in kwargs.items():
if key == 'num_nodes':
self.__num_nodes__ = item
else:
self[key] = item
edge_index
的每一列定义一条边,其中第一行为边起始节点的索引,第二行为边结束节点的索引。这种表示方法被称为COO格式(coordinate format),通常用于表示稀疏矩阵。PyG不是用稠密矩阵
A
∈
{
0
,
1
}
∣
V
∣
×
∣
V
∣
\mathbf{A} \in \{ 0, 1 \}^{|\mathcal{V}| \times |\mathcal{V}|}
A∈{0,1}∣V∣×∣V∣来持有邻接矩阵的信息,而是用仅存储邻接矩阵
A
\mathbf{A}
A中非
0
0
0元素的稀疏矩阵来表示图。
通常,一个图至少包含x, edge_index, edge_attr, y, num_nodes
5个属性,当图包含其他属性时,我们可以通过指定额外的参数使Data
对象包含其他的属性:
graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, num_nodes=num_nodes, other_attr=other_attr)
转dict
对象为Data
对象
我们也可以将一个dict
对象转换为一个Data
对象:
graph_dict = {
'x': x,
'edge_index': edge_index,
'edge_attr': edge_attr,
'y': y,
'num_nodes': num_nodes,
'other_attr': other_attr
}
graph_data = Data.from_dict(graph_dict)
from_dict
是一个类方法:
@classmethod
def from_dict(cls, dictionary):
r"""Creates a data object from a python dictionary."""
data = cls()
for key, item in dictionary.items():
data[key] = item
return data
注意:graph_dict
中属性值的类型与大小的要求与Data
类的构造函数的要求相同。
Data
对象转换成其他类型数据
我们可以将Data
对象转换为dict
对象:
def to_dict(self):
return {key: item for key, item in self}
或转换为namedtuple
:
def to_namedtuple(self):
keys = self.keys
DataTuple = collections.namedtuple('DataTuple', keys)
return DataTuple(*[self[key] for key in keys])
获取Data
对象属性
x = graph_data['x']
设置Data
对象属性
graph_data['x'] = x
获取Data
对象包含的属性的关键字
graph_data.keys()
对边排序并移除重复的边
graph_data.coalesce()
Data
对象的其他性质
我们通过观察PyG中内置的一个图来查看Data
对象的性质:
from torch_geometric.datasets import KarateClub
dataset = KarateClub()
data = dataset[0] # Get the first graph object.
print(data)
print('==============================================================')
# 获取图的一些信息
print(f'Number of nodes: {data.num_nodes}') # 节点数量
print(f'Number of edges: {data.num_edges}') # 边数量
print(f'Number of node features: {data.num_node_features}') # 节点属性的维度
print(f'Number of node features: {data.num_features}') # 同样是节点属性的维度
print(f'Number of edge features: {data.num_edge_features}') # 边属性的维度
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}') # 平均节点度
print(f'if edge indices are ordered and do not contain duplicate entries.: {data.is_coalesced()}') # 是否边是有序的同时不含有重复的边
print(f'Number of training nodes: {data.train_mask.sum()}') # 用作训练集的节点
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}') # 用作训练集的节点数占比
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}') # 此图是否包含孤立的节点
print(f'Contains self-loops: {data.contains_self_loops()}') # 此图是否包含自环的边
print(f'Is undirected: {data.is_undirected()}') # 此图是否是无向图
四、Dataset
类——PyG中图数据集的表示及其使用
PyG内置了大量常用的基准数据集,接下来我们以PyG内置的Planetoid
数据集为例,来学习PyG中图数据集的表示及使用。
Planetoid
数据集类的官方文档为torch_geometric.datasets.Planetoid。
生成数据集对象并分析数据集
如下方代码所示,在PyG中生成一个数据集是简单直接的。在第一次生成PyG内置的数据集时,程序首先下载原始文件,然后将原始文件处理成包含Data
对象的Dataset
对象并保存到文件。
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/dataset/Cora', name='Cora')
# Cora()
len(dataset)
# 1
dataset.num_classes
# 7
dataset.num_node_features
# 1433
分析数据集中样本
可以看到该数据集只有一个图,包含7个分类任务,节点的属性为1433维度。
data = dataset[0]
# Data(edge_index=[2, 10556], test_mask=[2708],
# train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
data.is_undirected()
# True
data.train_mask.sum().item()
# 140
data.val_mask.sum().item()
# 500
data.test_mask.sum().item()
# 1000
现在我们看到该数据集包含的唯一的图,有2708个节点,节点特征为1433维,有10556条边,有140个用作训练集的节点,有500个用作验证集的节点,有1000个用作测试集的节点。
作业
- 请通过继承
Data
类实现一个类,专门用于表示“机构-作者-论文”的网络。该网络包含“机构“、”作者“和”论文”三类节点,以及“作者-机构“和“作者-论文“两类边。对要实现的类的要求:1)用不同的属性存储不同节点的属性;2)用不同的属性存储不同的边(边没有属性);3)逐一实现获取不同节点数量的方法。
from torch_geometric.data import Data
class AcademicData(Data):
def __init__(self,author_feat=None,org_feat=None,paper_feat=None,a_o_edge=None,a_p_edge=None,num_authors=None,num_papers=None,num_orgs=None):
super().__init__()
self.author_feat = author_feat
self.org_feat = org_feat
self.paper_feat = paper_feat
self.a_o_edge = a_o_edge
self.a_p_edge = a_p_edge
self.num_authors = num_authors
self.num_papers = num_papers
self.num_orgs = num_orgs
@property
def num_nodes(self):
return self.num_authors + self.num_papers + self.num_orgs
from AcademicData import AcademicData
import numpy as np
import torch
def main():
num_author = 1000
num_org = 200
num_paper = 50000
author_feat = torch.randn(num_author,64)
org_feat = torch.randn(num_org,64)
paper_feat = torch.randn(num_paper,64)
a_o_edge = torch.from_numpy(np.random.randint(0,2,(num_author,num_org)))
a_p_edge = torch.from_numpy(np.random.randint(0,2,(num_author,num_paper)))
graph_dict = {
'author_feat': author_feat,
'org_feat': org_feat,
'paper_feat': paper_feat,
'a_o_edge': a_o_edge,
'a_p_edge': a_p_edge,
'edge_attr': None,
'num_authors': num_author,
'num_papers': num_org,
'num_orgs': num_paper,
'x':None,
'y':None
}
graph=AcademicData(graph_dict['author_feat'],graph_dict['org_feat'],
graph_dict['paper_feat'],graph_dict['a_o_edge'],graph_dict['a_p_edge'],
graph_dict['num_authors'],graph_dict['num_orgs'],graph_dict['num_papers'])
# 获取图的一些信息
print(f'Number of authors: {graph.num_authors}') # 节点数量
print(f'Number of orgs: {graph.num_orgs}') # 节点数量
print(f'Number of papers: {graph.num_papers}') # 节点数量
print(f'Number of nodes: {graph.num_nodes}') # 总节点数量
print(graph)
if __name__ == "__main__":
main()
参考
https://github.com/datawhalechina/team-learning-nlp/tree/master/GNN