Posts in this series (冬炫的博客, CSDN):
0. import torch_geometric: inspecting Data objects
1. import torch_geometric: loading some common datasets
2. torch_geometric: all about mini-batches
Mini-batches
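PyG builds a sparse block-diagonal adjacency matrix (defined by edge_index) and concatenates the node features and targets along the node dimension.
As a result, the number of nodes differs from batch to batch. Conventional batching cuts the data into equal slices, like a cake; here the total node count of every batch fed to the model is different.
This special way of handling mini-batches is implemented by a dedicated class, torch_geometric.loader.DataLoader: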
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    print(batch)
    # >>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])
    print(batch.num_graphs)
    # >>> 32
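The loader concatenates the graphs: the 32 graphs in this batch are merged into a single batch object with 1082 nodes in total, 21-dimensional node features, and 4066 edges.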
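The concatenation described above can be sketched in plain Python, so it runs without PyG installed. This is only an illustration of the idea, not PyG's actual implementation; the function name collate_graphs and the dict layout are hypothetical. The key step is shifting each graph's node indices by the number of nodes already placed, which is what makes the combined adjacency matrix block-diagonal.

```python
def collate_graphs(graphs):
    """Concatenate small graphs into one large, disconnected graph.

    Each graph is a dict with 'x' (list of node feature rows) and
    'edge_index' (two lists: source nodes and target nodes).
    """
    xs, src, dst = [], [], []
    offset = 0  # number of nodes already placed in the batch
    for g in graphs:
        xs.extend(g["x"])
        # Shift node indices so they point into the concatenated node list.
        src.extend(s + offset for s in g["edge_index"][0])
        dst.extend(d + offset for d in g["edge_index"][1])
        offset += len(g["x"])
    return {"x": xs, "edge_index": [src, dst], "num_graphs": len(graphs)}


# Two toy graphs: a 3-node path and a 2-node pair (edges stored in both directions).
g1 = {"x": [[1.0], [2.0], [3.0]], "edge_index": [[0, 1, 1, 2], [1, 0, 2, 1]]}
g2 = {"x": [[4.0], [5.0]], "edge_index": [[0, 1], [1, 0]]}

big = collate_graphs([g1, g2])
print(big["edge_index"])  # [[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]]
```

No edge connects nodes 0..2 with nodes 3..4, so message passing on the merged graph never mixes information between the original graphs.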
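torch_geometric.data.Batch inherits from torch_geometric.data.Data and carries one additional attribute, batch: a column vector that maps each node to the index of the graph it belongs to.
$$\mathrm{batch} = \begin{bmatrix} 0 & \cdots & 0 & 1 & \cdots & n-2 & n-1 & \cdots & n-1 \end{bmatrix}^{\top}$$
(with n = 32 graphs in the example above, the entries are 0, ..., 31)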
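To make the batch vector concrete, here is a small plain-Python sketch that builds it from per-graph node counts and then recovers the counts again. The helper name make_batch_vector is hypothetical and is not part of PyG.

```python
def make_batch_vector(nodes_per_graph):
    """Return a list with one entry per node holding that node's graph index."""
    batch = []
    for graph_id, num_nodes in enumerate(nodes_per_graph):
        batch.extend([graph_id] * num_nodes)
    return batch


batch = make_batch_vector([3, 2, 4])
print(batch)  # [0, 0, 0, 1, 1, 2, 2, 2, 2]

# The vector alone is enough to recover how many nodes each graph has.
counts = [batch.count(g) for g in range(max(batch) + 1)]
print(counts)  # [3, 2, 4]
```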
The batch vector can be used to compute, for each graph, the mean of its node features in every feature dimension:
from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
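for data in loader:
    print(data)
    # >>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])
    print(data.num_graphs)
    # >>> 32
    # data.batch assigns every node to its graph, so this averages per graph.
    x = scatter_mean(data.x, data.batch, dim=0)
    print(x.size())
    # >>> torch.Size([32, 21])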
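What the scatter mean does with the batch vector can be spelled out in plain Python. The sketch below is an illustration only, not torch-scatter's implementation; the name scatter_mean_rows is hypothetical, and it assumes every row of x has the same length.

```python
def scatter_mean_rows(x, index):
    """Average the rows of x that share the same value in index."""
    num_groups = max(index) + 1
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_groups)]
    counts = [0] * num_groups
    for row, group in zip(x, index):
        counts[group] += 1
        for j, value in enumerate(row):
            sums[group][j] += value
    # Groups that received no rows stay at 0.0 in this sketch.
    return [
        [s / counts[g] if counts[g] else 0.0 for s in sums[g]]
        for g in range(num_groups)
    ]


x = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]  # three nodes, two features each
batch = [0, 0, 1]                           # first two nodes belong to graph 0
print(scatter_mean_rows(x, batch))  # [[2.0, 3.0], [10.0, 20.0]]
```

The output has one row per graph, which is why the result above has shape [32, 21] for 32 graphs with 21 features.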
You can learn more about the internal batching procedure of PyG, e.g., how to modify its behaviour, in the PyG documentation. For documentation of scatter operations, we refer the interested reader to the torch-scatter documentation.
Data Transforms
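Transforms are data preprocessing methods (data augmentation, data conversion). Several of them can be chained together, much like image pipelines that first crop, then normalize, and so on.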
Transforms are a common way in torchvision to transform images and perform augmentation. PyG comes with its own transforms, which expect a Data object as input and return a new transformed Data object. Transforms can be chained together using torch_geometric.transforms.Compose and are applied before saving a processed dataset on disk (pre_transform) or before accessing a graph in a dataset (transform).
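Chaining is simple enough to sketch in plain Python. The class below only mimics the idea behind torch_geometric.transforms.Compose; the names ComposeSketch, center and scale_to_unit are hypothetical, and a dict stands in for a Data object.

```python
class ComposeSketch:
    """Apply a list of transforms one after another."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, data):
        for transform in self.transforms:
            data = transform(data)
        return data


def center(data):
    """Shift positions so that their mean is at the origin."""
    pos = data["pos"]
    dims = range(len(pos[0]))
    mean = [sum(p[d] for p in pos) / len(pos) for d in dims]
    return {**data, "pos": [[p[d] - mean[d] for d in dims] for p in pos]}


def scale_to_unit(data):
    """Divide by the largest absolute coordinate (or 1.0 if all are zero)."""
    biggest = max(abs(c) for p in data["pos"] for c in p) or 1.0
    return {**data, "pos": [[c / biggest for c in p] for p in data["pos"]]}


pipeline = ComposeSketch([center, scale_to_unit])
out = pipeline({"pos": [[0.0, 0.0], [2.0, 4.0]]})
print(out["pos"])  # [[-0.5, -1.0], [0.5, 1.0]]
```

Each transform returns a new object rather than modifying its input, which keeps the steps independent and the order of application explicit.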
Let's look at an example that applies transforms to the ShapeNet dataset (containing 17,000 3D shape point clouds and per point labels from 16 shape categories).
The snippet below loads only the samples of the Airplane category:
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'])
dataset[0]
>>> Data(pos=[2518, 3], y=[2518])
We can convert the point cloud dataset into a graph dataset by generating nearest neighbor graphs from the point clouds via transforms:
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                   pre_transform=T.KNNGraph(k=6))
dataset[0]
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
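The edge count in the output follows directly from the construction: 2518 points with k = 6 neighbors each give 2518 × 6 = 15108 edges. A brute-force plain-Python sketch of a nearest neighbor graph is shown below. It is an illustration only, not how T.KNNGraph is implemented; the name knn_edges is hypothetical, and the choice to point edges from neighbor to node is an assumption of this sketch.

```python
def knn_edges(pos, k):
    """Connect every point to its k nearest neighbors (brute force).

    Returns edge_index as [sources, targets], with an edge pointing
    from each neighbor to the point that selected it.
    """
    src, dst = [], []
    for i, p in enumerate(pos):
        dists = []
        for j, q in enumerate(pos):
            if j != i:  # a point is not its own neighbor
                d2 = sum((a - b) ** 2 for a, b in zip(p, q))
                dists.append((d2, j))
        dists.sort()
        for _, j in dists[:k]:
            src.append(j)
            dst.append(i)
    return [src, dst]


pos = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
edge_index = knn_edges(pos, k=2)
print(len(edge_index[0]))  # 8, i.e. 4 points * 2 neighbors
```

The nested loop costs O(N²) distance computations, which is fine for a toy example but is the reason real implementations rely on optimized neighbor search.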
Note
We use the pre_transform to convert the data before saving it to disk (leading to faster loading times). Note that the next time the dataset is initialized it will already contain graph edges, even if you do not pass any transform. If the pre_transform does not match with the one from the already processed dataset, you will be given a warning.
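The difference between the two hooks comes down to when they run. The plain-Python sketch below makes that visible by counting calls; TinyDataset, add_edges and mark_access are hypothetical names, and an in-memory list stands in for the processed files on disk.

```python
class TinyDataset:
    """Toy dataset: pre_transform runs once, transform runs on every access."""

    def __init__(self, raw, pre_transform=None, transform=None):
        # Stand-in for "process once and save to disk".
        self.processed = [pre_transform(d) if pre_transform else d for d in raw]
        self.transform = transform

    def __len__(self):
        return len(self.processed)

    def __getitem__(self, idx):
        data = self.processed[idx]
        return self.transform(data) if self.transform else data


calls = {"pre": 0, "live": 0}


def add_edges(data):
    calls["pre"] += 1
    return {**data, "edge_index": [[0], [1]]}


def mark_access(data):
    calls["live"] += 1
    return dict(data)


ds = TinyDataset([{"pos": [[0.0], [1.0]]}],
                 pre_transform=add_edges, transform=mark_access)
for _ in range(3):
    sample = ds[0]
print(calls)  # {'pre': 1, 'live': 3}
```

Expensive, deterministic steps such as graph construction therefore belong in pre_transform, while random augmentation belongs in transform so that it is resampled on each access.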
In addition, we can use the transform argument to randomly augment a Data object, e.g., translating each node position by a small number:
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                   pre_transform=T.KNNGraph(k=6),
                   transform=T.RandomTranslate(0.01))
dataset[0]
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
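The printed shapes are unchanged because the augmentation only perturbs the values stored in pos. A plain-Python sketch of such a jitter follows; it assumes each coordinate is shifted by a value drawn uniformly from [-translate, translate], and the function name random_translate is hypothetical rather than PyG's implementation.

```python
import random


def random_translate(pos, translate, rng=random):
    """Shift every coordinate by a uniform random offset in [-translate, translate]."""
    return [[c + rng.uniform(-translate, translate) for c in p] for p in pos]


rng = random.Random(0)  # seeded generator so the run is repeatable
pos = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
jittered = random_translate(pos, 0.01, rng)

# Same number of points and dimensions; only the values moved slightly.
max_shift = max(abs(a - b) for p, q in zip(pos, jittered) for a, b in zip(p, q))
print(len(jittered), len(jittered[0]), max_shift)
```

Because the offsets are drawn again on every call, accessing the same sample twice yields two slightly different point clouds, which is the point of the augmentation.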
You can find a complete list of all implemented transforms at torch_geometric.transforms.