【DGL教程】第4章 图数据集

本文档详细介绍了如何使用DGL提供的DGLDataset类来创建和处理自定义图数据集,包括下载、处理、保存和加载数据的步骤。DGL推荐通过继承DGLDataset来实现数据集,提供了下载、处理、保存和加载数据的方法模板。此外,还展示了图分类数据集QM7b和顶点分类数据集Citeseer的使用示例。
摘要由CSDN通过智能技术生成

官方文档:https://docs.dgl.ai/en/latest/guide/data.html

dgl.data实现了很多常用的图数据集,这些数据集都是dgl.data.DGLDataset的子类
DGL官方推荐通过继承dgl.data.DGLDataset来实现自己的数据集,从而可以更方便地加载、处理、保存图数据集

1.DGLDataset类

dgl.data.DGLDataset类处理数据集的流程包括以下几步:下载、处理、保存到磁盘、从磁盘加载,如下图所示
图数据集流程图

自定义数据集类:

from dgl.data import DGLDataset

class MyDataset(DGLDataset):
    def __init__(self):
        super().__init__(name='my_dataset', url='https://example.com/path/to/my_dataset.zip')

    def download(self):
        # download raw data to local disk
        pass

    def save(self):
        # save processed data to directory `self.save_path`
        pass

    def load(self):
        # load processed data from directory `self.save_path`
        pass

    def process(self):
        # process raw data to graphs, labels, splitting masks
        pass

    def has_cache(self):
        # check whether there are processed data in `self.save_path`
        pass

    def __getitem__(self, idx):
        # get one example by index
        pass

    def __len__(self):
        # number of data examples
        pass

其中process(), __getitem__(idx)__len__()是必须实现的方法

DGLDataset类的目的是提供一种标准、方便的加载图数据的方式,可以存储图、特征、标签、划分以及数据集的其他基本信息(如类别数)
下面介绍实现数据集中方法的最佳实践

1.1 下载原始数据

DGLDataset.download()方法用于从self.url指定的URL下载原始数据,并保存到self.raw_dir目录

  • DGL提供了一个辅助函数dgl.data.utils.download()用于从指定的URL下载文件
  • DGLDatasetraw_dir属性是原始数据下载目录,如果在构造函数中指定了raw_dir参数则使用指定的目录,如果未指定则使用环境变量DGL_DOWNLOAD_DIR指定的目录,如果该环境变量不存在则默认为~/.dgl
  • DGLDatasetraw_path属性是os.path.join(self.raw_dir, self.name)(也可以在子类中覆盖),可用于原始数据的解压目录(如果原始数据是zip文件)

示例:

def download(self):
    zip_file_path = os.path.join(self.raw_dir, 'my_dataset.zip')
    download(self.url, path=zip_file_path)
    extract_archive(zip_file_path, self.raw_path)

1.2 处理数据

DGLDataset.process()方法用于将self.raw_dirself.raw_path中的原始数据处理成DGLGraph的格式,一般包括读取原始数据、数据清洗、构造图、读取顶点特征和标签以及划分数据集等步骤,具体逻辑取决于原始数据的格式(可能是pkl, npz, mat, csv, txt等等),这是需要自己实现的主要部分(也是最麻烦的部分)
以只有一个图的数据集为例(例如顶点分类数据集),基本框架如下:

def process(self):
    data = _read_raw_data(self.raw_path)
    data = _clean(data)
    g = dgl.graph(...)
    g.ndata['feat'] = ...
    g.ndata['label'] = ...
    g.ndata['train_mask'] = ...
    g.ndata['val_mask'] = ...
    g.ndata['test_mask'] = ...
    self.g = g

def __getitem__(self, idx):
    if idx != 0:
        raise IndexError('This dataset has only one graph')
    return self.g

def __len__(self):
    return 1

注:

  • _read_raw_data()_clean()是需要自己实现的读取原始数据和清洗数据的逻辑
  • 该示例只有一个同构图,实际也可能是异构图,也可能包含多个图(例如图分类数据集)

1.3 保存和加载数据

DGL推荐实现数据集的save()load()方法,将预处理完的数据缓存到磁盘,下次使用时可直接从磁盘加载,不需要再执行process()方法,has_cache()方法返回磁盘上是否有缓存的已处理好的数据

DGL提供了4个函数:

  • dgl.save_graphs()dgl.load_graphs()用于向磁盘保存/从磁盘读取DGLGraph对象
  • dgl.data.utils.save_info()dgl.data.utils.load_info()用于向磁盘保存/从磁盘读取数据集的相关信息(实际上就是pickle.dump()pickle.load()

保存路径:

  • DGLDatasetsave_dir属性是处理好的数据的保存目录,如果在构造函数中指定了save_dir参数则使用指定的目录,否则默认为raw_dir
  • DGLDatasetsave_path属性是os.path.join(self.save_dir, self.name)(也可以在子类中覆盖),一般将处理好的数据保存到save_path目录下

典型用法:
(1)顶点分类数据集(只有一个图)

def save(self):
    # save graphs and labels
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
    save_graphs(graph_path, [self.g])

def load(self):
    # load processed data from directory `self.save_path`
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
    graphs, _ = load_graphs(graph_path)
    self.labels = label_dict['labels']
    self.g = graphs[0]

def has_cache(self):
    # check whether there are processed data in `self.save_path`
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
    return os.path.exists(graph_path)

(2)图分类数据集(包含多个图)

def save(self):
    # save graphs and labels
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
    save_graphs(graph_path, self.graphs, {'labels': self.labels})
    # save other information in python dict
    info_path = os.path.join(self.save_path, self.name + '_info.pkl')
    save_info(info_path, {'num_classes': self.num_classes})

def load(self):
    # load processed data from directory `self.save_path`
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
    self.graphs, label_dict = load_graphs(graph_path)
    self.labels = label_dict['labels']
    info_path = os.path.join(self.save_path, self.name + '_info.pkl')
    self.num_classes = load_info(info_path)['num_classes']

def has_cache(self):
    # check whether there are processed data in `self.save_path`
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
    info_path = os.path.join(self.save_path, self.name + '_info.pkl')
    return os.path.exists(graph_path) and os.path.exists(info_path)

2.使用图数据集

2.1 图分类数据集

图分类数据集和传统机器学习的数据集类似,包含了一组样本和对应的标签,只是每个样本是一个dgl.DGLGraph,标签是一个张量,样本的特征保存在不同的顶点特征或边特征中

下面以QM7b数据集为例演示使用方法
创建数据集

>>> from dgl.data import QM7bDataset
>>> qm7b = QM7bDataset()  # 首次使用会先下载数据集
>>> len(qm7b)
7211
>>> qm7b.num_labels
14
>>> g, label = qm7b[0]
>>> g
Graph(num_nodes=5, num_edges=25,
      ndata_schemes={}
      edata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)})
>>> g.edata
{'h': tensor([[36.8581],
        [ 2.8961],
        ...
        [ 0.5000]])}
>>> label
tensor([-4.2093e+02,  3.9695e+01,  6.2184e-01, -1.6013e+01,  4.1620e+00,
         3.6768e+01,  1.5725e+01, -3.9861e+00, -1.0949e+01,  1.3230e-01,
        -1.4134e+01,  1.0870e+00,  2.5346e+00,  2.4322e+00])

可以看到该数据集共有7211个样本,每个样本有14个标签(对应14个预测任务),第1个样本图有5个顶点、25条边,有一个名为h的边特征,维数为1

遍历数据集
可以使用PyTorch的DataLoader遍历数据集

from torch.utils.data import DataLoader

# load data
dataset = QM7bDataset()
num_labels = dataset.num_labels

# create collate_fn
def _collate_fn(batch):
    graphs, labels = batch
    g = dgl.batch(graphs)
    labels = torch.tensor(labels, dtype=torch.long)
    return g, labels

# create dataloaders
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=_collate_fn)

# training
for epoch in range(100):
    for g, labels in dataloader:
        # your training code here
        pass

2.2 顶点分类数据集

顶点分类通常只在一个图上进行,因此这类数据集只有一个图,样本特征和标签保存在顶点特征中
以Citeseer数据集为例,该数据集包含一个图,有3327个顶点、9228条边,特征、标签、训练集、验证集、测试集掩码分别在顶点特征feat, label, train_mask, val_mask, test_mask中,顶点特征为3703维,6个类别(标签范围为[0, 5])

>>> from dgl.data import CiteseerGraphDataset
>>> citeseer = CiteseerGraphDataset()
>>> len(citeseer)
1
>>> citeseer.num_classes
6
>>> g = citeseer[0]
>>> g
Graph(num_nodes=3327, num_edges=9228,
      ndata_schemes={'train_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'label': Scheme(shape=(), dtype=torch.int64), 'feat': Scheme(shape=(3703,), dtype=torch.float32)}
      edata_schemes={})
>>> g.ndata['feat'].shape
torch.Size([3327, 3703])
>>> g.ndata['label'].shape
torch.Size([3327])
>>> g.ndata['label'][:10]
tensor([3, 1, 5, 5, 3, 1, 3, 0, 3, 5])
>>> train_idx = torch.nonzero(g.ndata['train_mask']).squeeze()
>>> train_set = g.ndata['feat'][train_idx]
>>> train_set.shape
torch.Size([120, 3703])

2.3 连接预测数据集

连接预测数据集和顶点分类数据集类似,也只有一个图,但训练集、验证集、测试集掩码在边特征中,这类数据集有dgl.data.KnowledgeGraphDataset的几个子类

2.4 OGB数据集

Open Graph Benchmark (OGB): https://ogb.stanford.edu/docs/home/

构建包含100个数据集,可以按照以下步骤进行: 1. 定义结构:使用DGL中的Graph对象定义每个的结构,包括节点数、边数、节点和边的特征等。 2. 添加节点和边特征:使用DGL中的NodeDataLoader和EdgeDataLoader等数据加载器为节点和边添加特征信息。 3. 构建多个:使用Python的循环语句,根据定义好的结构和特征信息创建多个对象。 4. 划分数据集:使用DGL中的train_test_split_edges函数将每个划分为训练集、验证集和测试集。 5. 批量化数据:使用DGL中的GraphDataLoader函数将处理好的数据批量化,以便于输入模型进行训练和推理。 以下是一个简单的Python代码示例,用于构建包含100个数据集,并为每个添加节点和边特征: ``` import dgl import torch from dgl.data import DGLDataset from dgl.dataloading import GraphDataLoader from sklearn.model_selection import train_test_split class MyDataset(DGLDataset): def __init__(self): super().__init__(name='my_dataset') def process(self): # 定义结构 g = dgl.graph(([0, 1], [1, 0])) # 两个节点和一条边 # 添加节点特征 g.ndata['x'] = torch.tensor([[1.], [2.]]) # 添加边特征 g.edata['w'] = torch.tensor([3.]) # 构建多个 self.graphs = [g] * 100 # 划分数据集 train_graphs, valid_graphs, test_graphs = dgl.random.split_dataset(self.graphs, [0.6, 0.2, 0.2]) # 批量化数据 self.train_loader = GraphDataLoader(train_graphs, batch_size=1, shuffle=True) self.valid_loader = GraphDataLoader(valid_graphs, batch_size=1, shuffle=False) self.test_loader = GraphDataLoader(test_graphs, batch_size=1, shuffle=False) def __getitem__(self, idx): return self.graphs[idx] def __len__(self): return len(self.graphs) dataset = MyDataset() dataset.process() ``` 以上代码示例中,MyDataset类继承自DGLDataset类,实现了__init__、process、__getitem__和__len__等方法。在process方法中,先定义了一个包含两个节点和一条边的,并为节点和边添加了特征信息。然后使用Python的循环语句,根据该结构和特征信息创建了100个对象,并使用train_test_split函数将每个划分为训练集、验证集和测试集。最后使用GraphDataLoader函数将处理好的数据批量化,以便于输入模型进行训练和推理。 需要注意的是,以上代码示例仅用于说明构建包含100个数据集的基本步骤,实际应用中需要根据具体任务进行相应的修改。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值