All About Mini-Batches in torch_geometric

Posts in this series (冬炫's blog on CSDN):

0. Inspecting the Data object in torch_geometric

1. Loading common datasets with torch_geometric

2. All about mini-batches in torch_geometric (this post)

3. A first graph network example with torch_geometric

4. Message passing networks in torch_geometric

 


Mini-batches

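PyG creates a sparse block-diagonal adjacency matrix (defined by edge_index) and concatenates the feature and target matrices along the node dimension.

As a result, the number of nodes differs from batch to batch. Conventional batching cuts the data into equal slices like a cake; here, every batch fed to the library has a different total node count.

This special way of handling mini-batches is provided by a dedicated class, torch_geometric.loader.DataLoader: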

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    batch
    >>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])

    batch.num_graphs
    >>> 32

It concatenates the graphs: the 32 graphs are merged into a single batch with 1082 nodes in total, 21-dimensional node features, and 4066 edges.

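torch_geometric.data.Batch inherits from torch_geometric.data.Data and carries an additional attribute, batch: an index vector that gives, for each node, the number of the graph it belongs to.

$$\mathrm{batch} = \begin{bmatrix} 0 & \cdots & 0 & 1 & \cdots & n-2 & n-1 & \cdots & n-1 \end{bmatrix}^{\top}$$

(here n = 32, so the entries range over 0, ..., 31)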
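To make the block-diagonal idea concrete, here is a minimal library-free sketch of how several small graphs could be merged into one large graph. It is not PyG's implementation; the helper name `collate_graphs` and the dict layout are assumptions made for illustration. Each graph's edge indices are shifted by the number of nodes that precede it, and the `batch` vector records which graph each node came from.

```python
def collate_graphs(graphs):
    """Merge graphs into one disconnected graph (illustrative, not PyG code).

    Each graph is a dict with:
      "x":          list of per-node feature lists
      "edge_index": [sources, targets], using node ids local to that graph
    """
    x, src, dst, batch = [], [], [], []
    offset = 0  # number of nodes already placed in the merged graph
    for graph_id, g in enumerate(graphs):
        x.extend(g["x"])
        # Shift local node ids so they point at the right rows of the merged x.
        src.extend(s + offset for s in g["edge_index"][0])
        dst.extend(t + offset for t in g["edge_index"][1])
        # Every node of this graph gets the same graph id.
        batch.extend([graph_id] * len(g["x"]))
        offset += len(g["x"])
    return {"x": x, "edge_index": [src, dst], "batch": batch}


# Two toy graphs of different sizes: 3 nodes and 2 nodes.
g0 = {"x": [[1.0], [2.0], [3.0]], "edge_index": [[0, 1, 1, 2], [1, 0, 2, 1]]}
g1 = {"x": [[10.0], [20.0]], "edge_index": [[0, 1], [1, 0]]}

merged = collate_graphs([g0, g1])
print(merged["edge_index"])  # [[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]]
print(merged["batch"])       # [0, 0, 0, 1, 1]
```

Because no edge connects nodes from different graphs, message passing on the merged graph never mixes information between graphs.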

Computing the per-dimension mean of the node features of each graph:

from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for data in loader:
    data
    >>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])

    data.num_graphs
    >>> 32

    x = scatter_mean(data.x, data.batch, dim=0)
    x.size()
    >>> torch.Size([32, 21])

You can learn more about the internal batching procedure of PyG, e.g., how to modify its behaviour, in the PyG documentation. For documentation of scatter operations, we refer the interested reader to the torch-scatter documentation.
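As a rough mental model of what `scatter_mean(data.x, data.batch, dim=0)` computes, the sketch below averages node feature rows per graph in plain Python. The function name `mean_per_graph` is hypothetical, and the code ignores everything torch-scatter does for speed; it only illustrates the grouping logic.

```python
def mean_per_graph(x, batch, num_graphs):
    """Average the rows of x that share the same graph id in batch."""
    num_features = len(x[0])
    sums = [[0.0] * num_features for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for row, graph_id in zip(x, batch):
        counts[graph_id] += 1
        for j, value in enumerate(row):
            sums[graph_id][j] += value
    # Divide each per-graph sum by its node count (guard against empty graphs).
    return [[s / max(c, 1) for s in row_sum] for row_sum, c in zip(sums, counts)]


# 5 nodes with 2 features each; the first 3 belong to graph 0, the rest to graph 1.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [10.0, 0.0], [20.0, 2.0]]
batch = [0, 0, 0, 1, 1]

result = mean_per_graph(x, batch, num_graphs=2)
print(result)  # [[3.0, 4.0], [15.0, 1.0]]
```

The result has one row per graph and one column per feature, which matches the `[32, 21]` shape in the ENZYMES example. PyG also ships `torch_geometric.nn.global_mean_pool(x, batch)` for this kind of per-graph averaging.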

Data Transforms

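Transforms are data preprocessing methods (data augmentation, data conversion). Several of them can be chained together, much like image operations where you first crop, then normalize, and so on.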

Transforms are a common way in torchvision to transform images and perform augmentation. PyG comes with its own transforms, which expect a Data object as input and return a new transformed Data object. Transforms can be chained together using torch_geometric.transforms.Compose and are applied before saving a processed dataset on disk (pre_transform) or before accessing a graph in a dataset (transform).
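The chaining idea can be sketched without PyG as a tiny callable that applies transforms in order. `ComposeSketch`, `center`, and `scale_to_unit` are hypothetical names for this illustration; the real class is torch_geometric.transforms.Compose, and real transforms operate on Data objects rather than plain dicts.

```python
class ComposeSketch:
    """Apply a list of transforms one after another (illustrative only)."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        for transform in self.transforms:
            data = transform(data)
        return data


def center(data):
    """Shift positions so that their mean is at the origin."""
    pos = data["pos"]
    dims = len(pos[0])
    mean = [sum(p[d] for p in pos) / len(pos) for d in range(dims)]
    return {**data, "pos": [[p[d] - mean[d] for d in range(dims)] for p in pos]}


def scale_to_unit(data):
    """Scale positions so that the largest absolute coordinate is 1."""
    largest = max(abs(c) for p in data["pos"] for c in p) or 1.0
    return {**data, "pos": [[c / largest for c in p] for p in data["pos"]]}


pipeline = ComposeSketch([center, scale_to_unit])
sample = {"pos": [[0.0, 0.0], [4.0, 0.0], [2.0, 6.0]]}
out = pipeline(sample)
print(out["pos"])  # [[-0.5, -0.5], [0.5, -0.5], [0.0, 1.0]]
```

Each transform returns a new object instead of modifying its input, which mirrors the "return a new transformed Data object" contract described above.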

As an example, we apply transforms on the ShapeNet dataset (containing 17,000 3D shape point clouds and per point labels from 16 shape categories).

The snippet below appears to select only the samples of the Airplane category:

from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'])

dataset[0]
>>> Data(pos=[2518, 3], y=[2518])

We can convert the point cloud dataset into a graph dataset by generating nearest neighbor graphs from the point clouds via transforms:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6))

dataset[0]
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
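To see where the edge count comes from, the following brute-force sketch builds a k-nearest-neighbor edge list in plain Python. It is only an illustration: `knn_edges` is a hypothetical helper, the O(n²) search is far slower than what a real library would use, and the edge direction (neighbor to center) is an assumption of this sketch.

```python
import math


def knn_edges(pos, k):
    """Return [sources, targets] linking each point to its k nearest neighbors."""
    sources, targets = [], []
    for i, p in enumerate(pos):
        # Sort all other points by Euclidean distance to point i.
        others = sorted(
            (j for j in range(len(pos)) if j != i),
            key=lambda j: math.dist(p, pos[j]),
        )
        for j in others[:k]:
            sources.append(j)  # neighbor
            targets.append(i)  # center point
    return [sources, targets]


points = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [5.0, 5.0, 5.0]]
edge_index = knn_edges(points, k=2)
print(len(edge_index[0]))  # 8, i.e. 4 points * 2 neighbors
```

With k neighbors per point, the edge list has n * k columns, which is consistent with the output above: 2518 * 6 = 15108.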

Note

We use the pre_transform to convert the data before saving it to disk (leading to faster loading times). Note that the next time the dataset is initialized it will already contain graph edges, even if you do not pass any transform. If the pre_transform does not match with the one from the already processed dataset, you will be given a warning.
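The difference between the two hooks can be sketched with a toy in-memory dataset. `ToyDataset` is hypothetical and keeps its processed samples in a list instead of writing them to disk; it only illustrates when each hook runs.

```python
class ToyDataset:
    """Toy stand-in for a dataset with pre_transform and transform hooks."""

    def __init__(self, raw_samples, pre_transform=None, transform=None):
        self.transform = transform
        # pre_transform runs once per sample, when the dataset is processed.
        # A real dataset would save this result to disk and reload it later.
        self.processed = [
            pre_transform(s) if pre_transform is not None else s
            for s in raw_samples
        ]

    def __len__(self):
        return len(self.processed)

    def __getitem__(self, idx):
        sample = self.processed[idx]
        # transform runs on every access, so random augmentations differ each time.
        return self.transform(sample) if self.transform is not None else sample


calls = {"pre": 0, "on_access": 0}


def count_pre(sample):
    calls["pre"] += 1
    return sample


def count_access(sample):
    calls["on_access"] += 1
    return sample


dataset = ToyDataset([1, 2, 3], pre_transform=count_pre, transform=count_access)
for _ in range(2):  # read the whole dataset twice
    for i in range(len(dataset)):
        dataset[i]

print(calls)  # {'pre': 3, 'on_access': 6}
```

Deterministic, expensive steps such as building a neighbor graph fit the run-once hook, while random augmentation belongs in the per-access hook.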

In addition, we can use the transform argument to randomly augment a Data object, e.g., translating each node position by a small number:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6),
                    transform=T.RandomTranslate(0.01))

dataset[0]
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
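A minimal sketch of this kind of jitter, assuming each coordinate is shifted independently by a value drawn uniformly from [-0.01, 0.01]. `random_translate` is a hypothetical helper, not PyG's implementation. Only the positions change, so the shapes of `edge_index`, `pos`, and `y` stay the same as in the output above.

```python
import random


def random_translate(pos, max_shift, rng=random):
    """Shift every coordinate by an independent value in [-max_shift, max_shift]."""
    return [[c + rng.uniform(-max_shift, max_shift) for c in p] for p in pos]


rng = random.Random(0)  # fixed seed so the sketch is reproducible
pos = [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
jittered = random_translate(pos, max_shift=0.01, rng=rng)

# The shape is unchanged and no coordinate moved by more than max_shift.
print(len(jittered), len(jittered[0]))  # 2 3
```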

You can find a complete list of all implemented transforms at torch_geometric.transforms.
