A Detailed Guide to Data Handling in the geometric Library

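All the posts listed below are my personal explorations of EEG. The project code shown is an early, incomplete version; if you need the complete project code and materials, please contact me by private message.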


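Datasets
1. EEG project exploration and implementation (Part 1): selecting and introducing the SEED research dataset
Related paper readings and analyses:
1. Reading and analysis of the baseline paper by the authors of the EEG SEED dataset
2. Graph neural network EEG paper reading and analysis: "EEG-Based Emotion Recognition Using Regularized Graph Neural Networks"
3. EEG-GNN paper reading and analysis: "EEG Emotion Recognition Using Dynamical Graph Convolutional Neural Networks"
4. Paper reading and analysis: "Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification"
5. Paper reading and analysis: "DeepGCNs: Can GCNs Go as Deep as CNNs?"
6. Paper reading and analysis: "How Attentive are Graph Attention Networks?"
7. Paper reading and analysis: "Simplifying Graph Convolutional Networks"
8. Paper reading and analysis: "LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation"
9. Graph neural networks: overview and summary
Related experiments and code:
1. EEG data processing for graph neural networks
2. Training and testing a GCN on the public EEG SEED dataset
3. Training and testing a GAT on the public EEG SEED dataset
4. Training and testing SGC on the SEED dataset
5. Training and testing a Transformer on the public EEG SEED dataset
6. Training and testing RGNN on the public EEG SEED dataset
Supplementary learning materials:
1. Three simple Graph examples from the official site, illustrating the three levels of application
2. Learning graph neural networks with the PPI dataset example project
3. A detailed guide to data handling in the geometric library
4. NetworkX dicts of dicts and solving the Seven Bridges of Königsberg problem
5. geometric source code reading and analysis: the MessagePassing class explained, with usage
6. Learning graph neural networks with the cora dataset example project
7. Graph aggregation
8. Learning graph neural networks with the QM9 dataset example project
9. Open-source libraries for working with graphs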

A Detailed Guide to Data Handling in the geometric Library

Reference: torch_geometric.data, pytorch_geometric documentation (pytorch-geometric.readthedocs.io)

Data object structures

Data objects and their meaning:

  • Data – A data object describing a homogeneous graph.
  • HeteroData – A data object describing a heterogeneous graph, holding multiple node and/or edge types in disjunct storage objects.
  • Batch – A data object describing a batch of graphs as one big (disconnected) graph.
  • TemporalData – A data object composed of a stream of events describing a temporal graph.
  • Dataset – Dataset base class for creating graph datasets.
  • InMemoryDataset – Dataset base class for creating graph datasets which easily fit into CPU memory, which makes access fast.

Remote backend interfaces

Objects and their meaning:

  • FeatureStore – An abstract base class to access features from a remote feature store.
  • GraphStore – An abstract base class to access edges from a remote graph store.
  • TensorAttr – Defines the attributes of a FeatureStore tensor.
  • EdgeAttr – Defines the attributes of a GraphStore edge.

Lightweight PyTorch Lightning wrappers

These wrappers convert PyG objects into PyTorch Lightning data modules for multi-GPU training at three levels: graph, node, and link.

Objects and their meaning:

  • LightningDataset – Converts a set of Dataset objects into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU graph-level training via PyTorch Lightning.
  • LightningNodeData – Converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU node-level training via PyTorch Lightning.
  • LightningLinkData – Converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU link-level training (such as for link prediction) via PyTorch Lightning.

Helper functions

Download and extraction helpers that are convenient to use.

Functions and their meaning:

  • makedirs – Recursively creates a directory.
  • download_url – Downloads the content of a URL to a specific folder.
  • extract_tar – Extracts a tar archive to a specific folder.
  • extract_zip – Extracts a zip archive to a specific folder.
  • extract_bz2 – Extracts a bz2 archive to a specific folder.
  • extract_gz – Extracts a gz archive to a specific folder.
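
To make the descriptions of these helpers concrete, here is a minimal sketch that uses only the Python standard library (os, zipfile, tempfile) to do what makedirs and extract_zip are described as doing. It is not PyG's implementation; the folder layout and the file name sample.zip are assumptions made for the example, and the archive is built locally so no download is needed.

```python
import os
import tempfile
import zipfile

with tempfile.TemporaryDirectory() as root:
    # Recursively create the target folder (the role of `makedirs`).
    raw_dir = os.path.join(root, 'dataset', 'raw')
    os.makedirs(raw_dir, exist_ok=True)

    # Build a tiny archive locally so the sketch needs no network access.
    archive = os.path.join(root, 'sample.zip')
    with zipfile.ZipFile(archive, 'w') as zf:
        zf.writestr('some_file_1.txt', 'hello')

    # Extract the archive into the target folder (the role of `extract_zip`).
    with zipfile.ZipFile(archive, 'r') as zf:
        zf.extractall(raw_dir)

    extracted = sorted(os.listdir(raw_dir))
```

In a real dataset class, the PyG helpers would typically be called from download(), with self.raw_dir as the target folder.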

Data

torch_geometric.data.Data — pytorch_geometric documentation (pytorch-geometric.readthedocs.io)

import torch
from torch_geometric.data import Data

data = Data(x=x, edge_index=edge_index, ...)

# Add additional arguments to `data`:
data.train_idx = torch.tensor([...], dtype=torch.long)
data.test_mask = torch.tensor([...], dtype=torch.bool)

# Analyzing the graph structure:
data.num_nodes
>>> 23

data.is_directed()
>>> False

# PyTorch tensor functionality:
data = data.pin_memory()
data = data.to('cuda:0', non_blocking=True)
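
To make num_nodes and is_directed() more concrete, here is a plain-Python sketch (no PyG required) of a graph stored in the same COO layout as edge_index: two parallel rows of source and target node ids. The helper names num_nodes_from_edges and is_directed are hypothetical and only mimic the idea; PyG works on tensors and may infer the node count from other attributes such as x.

```python
# edge_index in COO layout: row 0 holds source nodes, row 1 holds target nodes.
# An undirected edge is stored as two directed edges (u -> v and v -> u).
edge_index = [
    [0, 1, 1, 2],  # sources
    [1, 0, 2, 1],  # targets
]


def num_nodes_from_edges(edge_index):
    """Infer the node count as the largest node id plus one (ids start at 0)."""
    return max(max(edge_index[0]), max(edge_index[1])) + 1


def is_directed(edge_index):
    """A graph is directed if some edge (u, v) has no reverse edge (v, u)."""
    edges = set(zip(edge_index[0], edge_index[1]))
    return any((v, u) not in edges for (u, v) in edges)


n = num_nodes_from_edges(edge_index)       # three nodes: 0, 1, 2
directed = is_directed(edge_index)         # every edge has its reverse
one_way = is_directed([[0, 1], [1, 2]])    # 0 -> 1 -> 2 has no reverse edges
```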

HeteroData

torch_geometric.data.HeteroData — pytorch_geometric documentation (pytorch-geometric.readthedocs.io)

Note that there are several ways to create a heterogeneous graph.

Method 1:

import torch
from torch_geometric.data import HeteroData

data = HeteroData()

# Create two node types "paper" and "author" holding a feature matrix:
data['paper'].x = torch.randn(num_papers, num_paper_features)
data['author'].x = torch.randn(num_authors, num_authors_features)

# Create an edge type "(author, writes, paper)" and building the
# graph connectivity:
data['author', 'writes', 'paper'].edge_index = ...  # [2, num_edges]

data['paper'].num_nodes
>>> 23

data['author', 'writes', 'paper'].num_edges
>>> 52

# PyTorch tensor functionality:
data = data.pin_memory()
data = data.to('cuda:0', non_blocking=True)

Method 2 (three equivalent ways to set the attributes of a node type):

from torch_geometric.data import HeteroData

data = HeteroData()
data['paper'].x = x_paper

data = HeteroData(paper={ 'x': x_paper })

data = HeteroData({'paper': { 'x': x_paper }})

Method 3 (three equivalent ways to set the attributes of an edge type):

data = HeteroData()
data['author', 'writes', 'paper'].edge_index = edge_index_author_paper

data = HeteroData(author__writes__paper={
    'edge_index': edge_index_author_paper
})

data = HeteroData({
    ('author', 'writes', 'paper'):
    { 'edge_index': edge_index_author_paper }
})
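
The three spellings in Method 3 end up under the same edge-type key. The sketch below illustrates why, using a hypothetical TinyHeteroStore class written in plain Python: a keyword such as author__writes__paper is split on the double underscore into the tuple ('author', 'writes', 'paper'). This is only a stand-in to show the key handling, it uses dictionary items instead of attributes, and it is not how HeteroData is actually implemented.

```python
def normalize_key(key):
    """Turn 'author__writes__paper' into ('author', 'writes', 'paper')."""
    if isinstance(key, str) and '__' in key:
        return tuple(key.split('__'))
    return key  # node types ('paper') and tuples are kept unchanged


class TinyHeteroStore:
    """Hypothetical stand-in: one attribute dict per node type or edge type."""

    def __init__(self, mapping=None, **kwargs):
        self.stores = {}
        for key, attrs in {**(mapping or {}), **kwargs}.items():
            self.stores[normalize_key(key)] = dict(attrs)

    def __getitem__(self, key):
        # store['a', 'rel', 'b'] passes the tuple ('a', 'rel', 'b') as key.
        return self.stores.setdefault(normalize_key(key), {})


edges = [[0, 1], [1, 0]]

a = TinyHeteroStore()
a['author', 'writes', 'paper']['edge_index'] = edges

b = TinyHeteroStore(author__writes__paper={'edge_index': edges})

c = TinyHeteroStore({('author', 'writes', 'paper'): {'edge_index': edges}})

same = a.stores == b.stores == c.stores
```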

Batch

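A DataLoader yields Batch objects.

A Batch describes a batch of graphs as one big (disconnected) graph. It inherits from torch_geometric.data.Data or torch_geometric.data.HeteroData. In addition, single graphs can be identified via the assignment vector batch, which maps each node to its respective graph identifier. A Batch is constructed from a Python list of Data or HeteroData objects, and the assignment vector batch is created on the fly. Assignment vectors are also created for each key in follow_batch, and any keys given in exclude_keys are excluded.

How it works:

Neural networks are usually trained in mini-batches. PyG parallelizes over a mini-batch by building a sparse block-diagonal adjacency matrix (defined by edge_index) and concatenating the feature and target matrices along the node dimension. This composition allows the examples in one batch to have different numbers of nodes and edges: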
$$
\mathbf{A} = \begin{bmatrix} \mathbf{A}_1 & & \\ & \ddots & \\ & & \mathbf{A}_n \end{bmatrix}, \qquad \mathbf{X} = \begin{bmatrix} \mathbf{X}_1 \\ \vdots \\ \mathbf{X}_n \end{bmatrix}, \qquad \mathbf{Y} = \begin{bmatrix} \mathbf{Y}_1 \\ \vdots \\ \mathbf{Y}_n \end{bmatrix}
$$
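
The block-diagonal construction can be sketched in plain Python. This is a simplified illustration under stated assumptions, not PyG's code: each graph is assumed to be a dict with an 'x' list (one feature row per node) and an 'edge_index' pair of lists, and collate_graphs is a hypothetical helper name. Node ids of each graph are shifted by the number of nodes seen so far, which is what places each adjacency block on the diagonal, and a batch vector records which graph every node came from.

```python
def collate_graphs(graphs):
    """Merge several graphs into one big disconnected graph."""
    x, src, dst, batch = [], [], [], []
    offset = 0
    for graph_id, g in enumerate(graphs):
        num_nodes = len(g['x'])
        x.extend(g['x'])  # stack X_1 ... X_n along the node dimension
        # Shift node ids so the graphs do not share nodes (block-diagonal A).
        src.extend(u + offset for u in g['edge_index'][0])
        dst.extend(v + offset for v in g['edge_index'][1])
        batch.extend([graph_id] * num_nodes)  # node -> graph identifier
        offset += num_nodes
    return {'x': x, 'edge_index': [src, dst], 'batch': batch}


# Two graphs with different numbers of nodes and edges.
g1 = {'x': [[1.0], [2.0]], 'edge_index': [[0, 1], [1, 0]]}
g2 = {'x': [[3.0], [4.0], [5.0]], 'edge_index': [[0, 1], [1, 2]]}

big = collate_graphs([g1, g2])
```

Because the second graph's node ids start at 2 in the merged graph, no edge connects the two graphs, so message passing on the big graph never mixes information between examples.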

The from_data_list method constructs a Batch object from a Python list of Data or HeteroData objects:

def from_data_list(data_list: List[BaseData], follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None)

Dataset and InMemoryDataset

PARAMETERS

  • root (str, optional) – Root directory where the dataset should be saved. (default: None)
  • transform (callable, optional) – A function/transform that takes in a Data object and returns a transformed version. The data object will be transformed before every access. (default: None)
  • pre_transform (callable, optional) – A function/transform that takes in a Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)
  • pre_filter (callable, optional) – A function that takes in a Data object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: None)
  • log (bool, optional) – Whether to print any console output while downloading and processing the dataset. (default: True)
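
The difference between transform, pre_transform, and pre_filter is mostly about when each hook runs. The following plain-Python sketch uses a hypothetical TinyDataset class (not a PyG class) with integers standing in for Data objects: pre_filter and pre_transform run once while the dataset is built, and transform runs again on every access.

```python
class TinyDataset:
    """Hypothetical dataset showing when each hook runs."""

    def __init__(self, samples, transform=None, pre_transform=None,
                 pre_filter=None):
        self.transform = transform
        # "Processing" step: filter first, then pre-transform, exactly once.
        if pre_filter is not None:
            samples = [s for s in samples if pre_filter(s)]
        if pre_transform is not None:
            samples = [pre_transform(s) for s in samples]
        self._stored = samples  # what would be saved to disk

    def __len__(self):
        return len(self._stored)

    def __getitem__(self, idx):
        sample = self._stored[idx]
        # `transform` is re-applied on every access and never stored.
        return sample if self.transform is None else self.transform(sample)


calls = {'transform': 0, 'pre_transform': 0}


def counted_pre_transform(s):
    calls['pre_transform'] += 1
    return s * 10


def counted_transform(s):
    calls['transform'] += 1
    return s + 1


ds = TinyDataset([1, 2, 3], transform=counted_transform,
                 pre_transform=counted_pre_transform,
                 pre_filter=lambda s: s != 2)
first_twice = [ds[0], ds[0]]  # accessing the same sample two times
```

Expensive, deterministic preprocessing therefore fits pre_transform, while random augmentation fits transform, since it should differ between accesses.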

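Basic skeleton:

This skeleton supports pre_transform and pre_filter. Both run once during processing and their results are saved to disk, so the work is not repeated on later runs, which saves resources. In addition, the parent class provides __getitem__, so a single sample can be accessed by index as dataset[index].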

import torch
from torch_geometric.data import InMemoryDataset, download_url


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Download to `self.raw_dir`.
        download_url(url, self.raw_dir)
        ...

    def process(self):
        # Read data into huge `Data` list.
        data_list = [...]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
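
The pair (data, slices) returned by collate can be pictured as one big concatenated storage plus the boundaries of every graph inside it. The sketch below shows that idea in plain Python with lists instead of tensors; the function names collate and get and the attribute names 'x' and 'y' are assumptions chosen to mirror the skeleton above, not PyG's actual implementation.

```python
def collate(graphs):
    """Concatenate per-graph lists and record where each graph starts and ends."""
    data = {'x': [], 'y': []}
    slices = {'x': [0], 'y': [0]}
    for g in graphs:
        for key in data:
            data[key].extend(g[key])
            slices[key].append(len(data[key]))  # end boundary of this graph
    return data, slices


def get(data, slices, idx):
    """Rebuild graph `idx` by slicing every attribute between its boundaries."""
    return {key: data[key][slices[key][idx]:slices[key][idx + 1]]
            for key in data}


graphs = [
    {'x': [0.1, 0.2], 'y': [0]},
    {'x': [0.3, 0.4, 0.5], 'y': [1]},
]

data, slices = collate(graphs)
second = get(data, slices, 1)
```

Storing one large object and a small table of boundaries keeps the whole dataset in a single file and makes reconstructing any sample a cheap slicing operation.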

For large datasets that do not fit into memory, the samples have to stay on disk. In that case, use Dataset as the parent class:

import os.path as osp

import torch
from torch_geometric.data import Data, Dataset, download_url


class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.
        path = download_url(url, self.raw_dir)
        ...

    def process(self):
        idx = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data
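
The key property of the on-disk variant is that each sample lives in its own file and is only read when requested. Here is a self-contained sketch of that pattern using pickle and a temporary directory; DiskDataset is a hypothetical class, the .pkl file naming is an assumption, and plain dicts stand in for Data objects.

```python
import os
import pickle
import tempfile


class DiskDataset:
    """Hypothetical on-disk dataset: one file per sample, loaded lazily."""

    def __init__(self, root, samples):
        self.root = root
        os.makedirs(root, exist_ok=True)
        # "Process" step: write every sample to its own file.
        for idx, sample in enumerate(samples):
            with open(self._path(idx), 'wb') as f:
                pickle.dump(sample, f)
        self._len = len(samples)

    def _path(self, idx):
        return os.path.join(self.root, f'data_{idx}.pkl')

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        # Only the requested sample is read from disk.
        with open(self._path(idx), 'rb') as f:
            return pickle.load(f)


with tempfile.TemporaryDirectory() as root:
    ds = DiskDataset(root, samples=[{'y': 0}, {'y': 1}, {'y': 2}])
    size = len(ds)
    last = ds[2]
    files = sorted(os.listdir(root))
```

Memory use stays bounded by the size of one sample (or one batch), at the cost of a file read per access.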

TemporalData

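A data object composed of a stream of events describing a temporal graph. A TemporalData object can hold a list of events with structured messages (they can be understood as temporal edges in a graph). An event is composed of a source node, a destination node, a timestamp, and a message. Any continuous-time dynamic graph (CTDG) can be represented with these four values.

In general, TemporalData tries to mimic the behavior of a regular Python dictionary. In addition, it provides useful functionality for analyzing graph structures, and provides basic PyTorch tensor functionality.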

class TemporalData(src: Optional[Tensor] = None, dst: Optional[Tensor] = None, t: Optional[Tensor] = None, msg: Optional[Tensor] = None, **kwargs)

import torch
from torch_geometric.data import TemporalData

# torch.tensor keeps integer inputs as integer (long) tensors, which suits
# node ids and timestamps; the messages are given as float features.
events = TemporalData(
    src=torch.tensor([1, 2, 3, 4]),
    dst=torch.tensor([2, 3, 4, 5]),
    t=torch.tensor([1000, 1010, 1100, 2000]),
    msg=torch.tensor([1.0, 1.0, 0.0, 0.0]),
)

# Add additional arguments to `events`:
events.y = torch.tensor([1, 1, 0, 0])

# It is also possible to set additional arguments in the constructor
events = TemporalData(
    ...,
    y=torch.tensor([1, 1, 0, 0]),
)

# Get the number of events:
events.num_events
>>> 4

# Analyzing the graph structure:
events.num_nodes
>>> 5

# PyTorch tensor functionality:
events = events.pin_memory()
events = events.to('cuda:0', non_blocking=True)
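
To illustrate how four values per event are enough to describe a graph that changes over time, here is a plain-Python sketch in which each event is a (src, dst, t, msg) tuple. The helper edges_until is hypothetical and not part of PyG: it recovers the edges that exist up to a given timestamp by filtering the event stream.

```python
# Each event is (src, dst, t, msg), mirroring the four fields of TemporalData.
# The events are deliberately listed out of time order.
events = [
    (3, 4, 1100, 0.0),
    (1, 2, 1000, 1.0),
    (4, 5, 2000, 0.0),
    (2, 3, 1010, 1.0),
]


def edges_until(events, t_max):
    """Return the (src, dst) edges with timestamp <= t_max, in time order."""
    ordered = sorted(events, key=lambda e: e[2])
    return [(src, dst) for src, dst, t, _msg in ordered if t <= t_max]


snapshot = edges_until(events, 1050)   # graph as seen at time 1050
full = edges_until(events, 2000)       # all events have happened by then
num_events = len(events)
```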

lightning.LightningDataset

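Converts a set of Dataset objects into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU graph-level training via PyTorch Lightning. LightningDataset takes care of providing mini-batches via DataLoader.

Currently, only the pytorch_lightning.strategies.SingleDeviceStrategy and pytorch_lightning.strategies.DDPSpawnStrategy training strategies of PyTorch Lightning are supported, so that data is shared correctly across all devices/processes: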

import pytorch_lightning as pl
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule)

lightning.LightningNodeData

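Converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU node-level training via PyTorch Lightning. LightningNodeData takes care of providing mini-batches via NeighborLoader.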

import pytorch_lightning as pl
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule)

lightning.LightningLinkData

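Converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU link-level training (such as link prediction) via PyTorch Lightning. LightningLinkData takes care of providing mini-batches via LinkNeighborLoader.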

import pytorch_lightning as pl
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule)