torch_geometric 使用 batch

基于GCNConv, SAGEConv

问题描述

基于torch.geometric的data具有node和edge的特征,其中node的维度表示为[num_nodes, node features],edge的维度表示为[2, edge_features]

torch.geometric的batch不支持CNN中广泛使用的batch维度。CNN中使用batch的数据具有[batch_size, channel, h, w]的维度。如果使用[batch_size, num_nodes, node_features][batch_size, 2, edge_features],代码会有如下报错:

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size XX but got size XXX for tensor number 1 in the list.

解决方案:

解决方案

完整问题

更多细节

在对角线上重复edge_index特征

batch_edge_index = Batch.from_data_list([Data(edge_index=edge_index)]*batch_size)

使用node_dim

conv = GCNConv(in_channels, out_channels, node_dim=1)
conv(x,edge_index)
# 在这种情况下,x(node)的维度是[batch_size, num_nodes, num_features]
# edge_index的维度是(2, edge_features)
# edge_index holds indices < num_nodes

edge_index always need to be two-dimensional, so the mini-batching of [batch_size, num_nodes, num_features] node features is designed to operate on the same underlying graph of shape [2, num_edges]. If your graph is not static across examples, you will have to use PyG’s approach of diagonal adjacency matrix stacking.

⚠️值得注意的是 edge_index中起始点和终点的index都必须小于nodes的个数(因为默认从0开始)

基于 GATConv

问题描述

与GCNConv, SAGEConv不同的是,当使用

x(node)的维度是[batch_size, num_nodes, num_features]
edge_index的维度是(2, edge_features)

会遇到如下报错:Static graphs not supported in GATConv

解决方案

解决方案,可以使用Batch.from_data_list(data_list)将数据转换成Batch格式,或者手动将node concate成[batchsize*num_nodes, num_features],edge_index concate成[2, batch_size*edge_features]
具体维度解析

edge_index_s = torch.tensor([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
])
x_s = torch.randn(5, 16)  # 5 nodes.
edge_index_t = torch.tensor([
    [0, 0, 0],
    [1, 2, 3],
])
x_t = torch.randn(4, 16)  # 4 nodes.

edge_index_3 = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x_3 = torch.randn(4, 16)

data1= Data(x=x_s,edge_index=edge_index_s)
data2= Data(x=x_t,edge_index=edge_index_t)
data3= Data(x=x_3,edge_index=edge_index_3)
#上面是构建3张Data图对象
# * `Batch(Data)` in case `Data` objects are batched together
#* `Batch(HeteroData)` in case `HeteroData` objects are batched together

data_list = [data1, data2,data3]


loader = Batch.from_data_list(data_list)#调用该函数data_list里的data1、data2、data3 三张图形成一张大图,也就是batch
————————————————
版权声明:本文为CSDN博主「知行合一,至于至善」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_41800917/article/details/120444534
torch_geometric.loader.DataLoader是PyG中的一个类,用于加载和处理数据。它可以将多个批处理成单个巨型,并提供了一些方便的功能。\[2\] 您可以使用torch_geometric.loader.DataLoader来加载和处理数据集。例如,您可以创建一个包含torch_geometric.data.Data对象的常规Python列表,并将其传递给DataLoader来批处理这些数据。\[1\] DataLoader还可以接受一些参数,例如batch_size和shuffle,以控制批处理的大小和数据的顺序。您还可以使用其他可以传递给PyTorch DataLoader的参数,例如num_workers。\[2\] 使用DataLoader加载数据集的示例代码如下:\[3\] ```python from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True) loader = DataLoader(dataset, batch_size=32, shuffle=True) for batch in loader: # 在这里对批处理的数据进行处理 # 例如,计算每个的节点维度中的平均节点特征 x = scatter_mean(batch.x, batch.batch, dim=0) print(x.size()) # 输出每个的节点特征的大小 ``` 通过使用torch_geometric.loader.DataLoader,您可以方便地加载和处理数据集。它提供了一种简单而有效的方式来处理大规模的数据。\[3\] #### 引用[.reference_title] - *1* *3* [【PyG】文档总结以及项目经验(持续更新](https://blog.csdn.net/weixin_45928096/article/details/125501673)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [第十九课.Pytorch-geometric扩展](https://blog.csdn.net/qq_40943760/article/details/120265255)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值