PyTorch Geometric的Mini-batches

官方文档 链接

加载ENZYMES数据集

from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader


dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)

loader = DataLoader(dataset, batch_size=4, shuffle=True)

ENZYMES数据集

batch

获取一个batch

batch = loader.__iter__().next()
print(batch)
# Batch(batch=[169], edge_index=[2, 556], ptr=[5], x=[169, 21], y=[4])

由于batch_size=4,所以batch中有4个图。batch的属性如图所示:
在这里插入图片描述

batch.keys
# ['x', 'edge_index', 'y', 'batch', 'ptr']

batch[0].keys
# ['x', 'edge_index', 'y']
取出单个数据
for i in range(batch.num_graphs):
    print(batch[i])
"""
Data(edge_index=[2, 178], x=[50, 21], y=[1])
Data(edge_index=[2, 114], x=[30, 21], y=[1])
Data(edge_index=[2, 160], x=[60, 21], y=[1])
Data(edge_index=[2, 104], x=[29, 21], y=[1])
"""
ptr属性

注意ptr这个属性,如果要把batch中的4个图取出来需要这个属性。
在这里插入图片描述

  • batch[0]就是[0:50] 50-0=50
  • batch[1]就是[50:80] 80-50=30
  • batch[2]就是[80:140] 140-80=60
  • batch[3]就是[140:169] 169-140=29
batch属性

输出batch属性查看一下
在这里插入图片描述
发现连续50个0,30个1,60个2,29个3

batch是怎么区分数据包括哪些的
batch.__slices__
""
{'y': [0, 1, 2, 3, 4], 
'x': [0, 50, 80, 140, 169], 
'edge_index': [0, 178, 292, 452, 556]}
""

获取batch[0]的时候,根据batch.__slices__

  • batch[0]['y'] = batch['y'][ batch.__slices__['y'][0]:batch.__slices__['y'][0+1] ]
  • batch[0]['x'] = batch['x'][ batch.__slices__['x'][0]:batch.__slices__['x'][0+1] ]
  • batch[0]['edge_index'] = batch['edge_index'][ batch.__slices__['edge_index'][0]:batch.__slices__['edge_index'][0+1] ]

获取batch[1]batch[2]、… 、batch[n]的时候,只用将 0 0 0改为相应的下标即可

PyTorch Geometric(PYG)-实现小批量data类中__inc__与__cat_dim__的含义与作用

https://blog.csdn.net/qq_41795143/article/details/114281387

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值