基于GCNConv, SAGEConv
问题描述
基于torch.geometric的data具有node和edge的特征,其中node的维度表示为[num_nodes, node features]
,edge的维度表示为[2, edge_features]
。
torch.geometric的batch不支持CNN中广泛使用的batch维度。CNN中使用batch的数据具有[batch_size, channel, h, w]
的维度。如果使用[batch_size, num_nodes, node_features]
和[batch_size, 2, edge_features]
,代码会有如下报错:
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size XX but got size XXX for tensor number 1 in the list.
解决方案:
在对角线上重复edge_index特征
batch_edge_index = Batch.from_data_list([Data(edge_index=edge_index)]*batch_size)
使用node_dim
conv = GCNConv(in_channels, out_channels, node_dim=1)
conv(x,edge_index)
# 在这种情况下,x(node)的维度是[batch_size, num_nodes, num_features]
# edge_index的维度是(2, edge_features)
# edge_index holds indices < num_nodes
edge_index
always need to be two-dimensional, so the mini-batching of [batch_size, num_nodes, num_features]
node features is designed to operate on the same underlying graph of shape [2, num_edges]
. If your graph is not static across examples, you will have to use PyG’s approach of diagonal adjacency matrix stacking.
⚠️值得注意的是 edge_index
中起始点和终点的index都必须小于nodes的个数(因为默认从0开始)
基于 GATConv
问题描述
与GCNConv, SAGEConv不同的是,当使用
x(node)的维度是[batch_size, num_nodes, num_features]
edge_index的维度是(2, edge_features)
会遇到如下报错:Static graphs not supported in GATConv
。
解决方案
解决方案,可以使用Batch.from_data_list(data_list)
将数据转换成Batch格式,或者手动将node concate成[batchsize*num_nodes, num_features]
,edge_index concate成[2, batch_size*edge_features]
。
具体维度解析:
edge_index_s = torch.tensor([
[0, 0, 0, 0],
[1, 2, 3, 4],
])
x_s = torch.randn(5, 16) # 5 nodes.
edge_index_t = torch.tensor([
[0, 0, 0],
[1, 2, 3],
])
x_t = torch.randn(4, 16) # 4 nodes.
edge_index_3 = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
x_3 = torch.randn(4, 16)
data1= Data(x=x_s,edge_index=edge_index_s)
data2= Data(x=x_t,edge_index=edge_index_t)
data3= Data(x=x_3,edge_index=edge_index_3)
#上面是构建3张Data图对象
# * `Batch(Data)` in case `Data` objects are batched together
#* `Batch(HeteroData)` in case `HeteroData` objects are batched together
data_list = [data1, data2,data3]
loader = Batch.from_data_list(data_list)#调用该函数data_list里的data1、data2、data3 三张图形成一张大图,也就是batch
————————————————
版权声明:本文为CSDN博主「知行合一,至于至善」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_41800917/article/details/120444534