torch_geometric笔记:数据集 ENZYMES &Minibatches

         Pytorch Geometric中包含大量的常见基准数据集。在初始化数据集的时候,框架会自动下载数据集的原始文件,并将其处理为Data对象。例如要下载ENZYMES数据集(由600个graph划分为6个类别)

1 下载数据集

from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='', name='ENZYMES')

dataset
#ENZYMES(600)

type(dataset)
#torch_geometric.datasets.tu_dataset.TUDataset

len(dataset)
#600
#说明600张图

dataset.num_classes
#6
#图一共有6各不同的类

dataset.num_node_features
#3 每一个节点有三个特征

data = dataset[0]
data
#Data(edge_index=[2, 168], x=[37, 3], y=[1])
#第一张图有168条有向边,37个节点,每个节点3个特征,整张图有一个类别

data.is_undirected()
#True

2 Mini-batches

        神经网络通常以batch的方式进行训练,geometric在mini-batch实现了并行化,这种组合允许在一个batch中使用不同数量的边和节点。

        在torch_geometric.data.DataLoader中,已经包含了此过程。

        这种mini-batch的操作本质上来说是将一个batch的graph看成是一个大的graph,由此,无论batch size是多少,其将所有的操作都统一在一个大图上进行操作。

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    print(batch,batch.num_graphs)

'''
Batch(edge_index=[2, 3890], x=[1075, 21], y=[32], batch=[1075], ptr=[33]) 32
Batch(edge_index=[2, 4284], x=[1157, 21], y=[32], batch=[1157], ptr=[33]) 32
Batch(edge_index=[2, 4098], x=[1086, 21], y=[32], batch=[1086], ptr=[33]) 32
Batch(edge_index=[2, 3668], x=[916, 21], y=[32], batch=[916], ptr=[33]) 32
Batch(edge_index=[2, 4062], x=[1074, 21], y=[32], batch=[1074], ptr=[33]) 32
Batch(edge_index=[2, 4086], x=[1096, 21], y=[32], batch=[1096], ptr=[33]) 32
Batch(edge_index=[2, 3954], x=[1005, 21], y=[32], batch=[1005], ptr=[33]) 32
Batch(edge_index=[2, 4170], x=[1064, 21], y=[32], batch=[1064], ptr=[33]) 32
Batch(edge_index=[2, 4258], x=[1149, 21], y=[32], batch=[1149], ptr=[33]) 32
Batch(edge_index=[2, 3836], x=[997, 21], y=[32], batch=[997], ptr=[33]) 32
Batch(edge_index=[2, 3886], x=[1016, 21], y=[32], batch=[1016], ptr=[33]) 32
Batch(edge_index=[2, 4066], x=[1042, 21], y=[32], batch=[1042], ptr=[33]) 32
Batch(edge_index=[2, 3946], x=[1046, 21], y=[32], batch=[1046], ptr=[33]) 32
Batch(edge_index=[2, 3656], x=[927, 21], y=[32], batch=[927], ptr=[33]) 32
Batch(edge_index=[2, 4110], x=[1034, 21], y=[32], batch=[1034], ptr=[33]) 32
Batch(edge_index=[2, 3824], x=[1002, 21], y=[32], batch=[1002], ptr=[33]) 32
Batch(edge_index=[2, 4178], x=[1116, 21], y=[32], batch=[1116], ptr=[33]) 32
Batch(edge_index=[2, 3736], x=[974, 21], y=[32], batch=[974], ptr=[33]) 32
Batch(edge_index=[2, 2856], x=[804, 21], y=[24], batch=[804], ptr=[25]) 24
'''

以  Batch(edge_index=[2, 3890], x=[1075, 21], y=[32], batch=[1075], ptr=[33])  为例:

  • edge_index=[2, 3890]——这个batch一共3890条边
  • x=[1075, 21]——整个batch的节点特征矩阵,这个batch一共2075个点,至于这个21,我不太明白,是因为不同的图有不同的特征,所以拼起来一共21个不同的特征吗?欢迎大家在评论区指正!
  • y=[32]——32个图,32维特征
  • batch=[1075]——batch是一个列向量,它将每个节点映射到该batch中的对应的graph:

       

        至于这个ptr,查了很多资料,都没一个说法。

        于是我自己做了一些尝试:感觉可能是这个意思(欢迎指正哈):就是这个batch目前累计看到的图的节点数量

        因为实验是后来补的,所以了不同的图

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
 
dataset = TUDataset(root='', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
 
for batch in loader:
    print(batch,batch.num_graphs)
    break
#Batch(edge_index=[2, 3822], x=[980, 21], y=[32], batch=[980], ptr=[33]) 32

batch['ptr']
'''
tensor([  0,  41,  66,  78, 122, 151, 193, 229, 261, 284, 350, 397, 429, 453,
        493, 534, 576, 588, 605, 634, 644, 660, 693, 723, 763, 811, 834, 847,
        887, 899, 940, 956, 980])
'''

sum=[0]
for i in range(32):
    sum.append(sum[-1]+int(batch[i]['x'].shape[0]))
print(sum)

'''
[0, 41, 66, 78, 122, 151, 193, 229, 261, 284, 350, 397, 429, 453, 493, 534, 576, 588, 605, 634, 644, 660, 693, 723, 763, 811, 834, 847, 887, 899, 940, 956, 980]
'''

2.1 自己的图列表 &DataLoader

不难发现,这种下载的数据集,可以看成是图的集合

那么如果我门自己设计了一些图,集合成一个列表,我们可以直接用这个列表构造DataLoader(注:这里的DataLoader是torch_geometric.loader的DataLoader)

 

评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值