前言:在 PyTorch Geometric 中,edge_index 是一个常用的数据结构。它通常表示为一个形状为 [2, num_edges] 的二维张量
一、 edge_index的基本性质(水点文字)
1、二维张量:edge_index
是一个二维张量,形状为 [2, num_edges]
。第一个维度表示边的起点,第二个维度表示边的终点。
2、零基索引:edge_index
使用零基索引(0-based indexing)
,即节点编号从 0 开始。
3、有向边:edge_index
通常表示有向边,即 edge_index[0, i]
是边的起点,edge_index[1, i]
是边的终点。如果需要无向图,可以通过增加反向边来实现。
4、灵活性:该参数允许灵活定义任意图结构,只要边索引符合上述形状和格式即可。
5、兼容性:edge_index
是 PyTorch Geometric
库中广泛使用的标准数据格式,确保了与该库中其他模块和函数的兼容性。
二、edge_index在数据封装时的自动映射
1、比如蛋白图分子的残基原子如下:每个蛋白质的原子数在400-1200之间不等。如果一次处理一条蛋白质还好,但是当一次处理一个批次时,原子的序号标识便不再唯一,存在大量的重叠,因此 torch_geometric的DATA类的edge_index变量对自动对数据进行映射
tensor([[ 0, 1, 2, ..., 503, 503, 503],
[504, 504, 504, ..., 500, 501, 502]])
tensor([[ 0, 1, 2, ..., 454, 455, 455],
[456, 456, 456, ..., 455, 453, 454]])
tensor([[ 0, 1, 2, ..., 571, 571, 571],
[572, 572, 572, ..., 568, 569, 570]])
tensor([[ 0, 1, 2, ..., 464, 464, 464],
[465, 465, 465, ..., 461, 462, 463]])
tensor([[ 0, 1, 2, ..., 465, 465, 465],
[466, 466, 466, ..., 462, 463, 464]])
tensor([[ 0, 1, 2, ..., 1114, 1114, 1114],
[1115, 1115, 1115, ..., 1111, 1112, 1113]])
tensor([[ 0, 1, 2, ..., 504, 505, 505],
[506, 506, 506, ..., 505, 503, 504]])
tensor([[ 0, 1, 2, ..., 1041, 1042, 1042],
[1043, 1043, 1043, ..., 1042, 1040, 1041]])
tensor([[ 0, 1, 2, ..., 449, 450, 450],
[451, 451, 451, ..., 450, 448, 449]])
tensor([[ 0, 1, 2, ..., 546, 546, 546],
[547, 547, 547, ..., 543, 544, 545]])
2、映射结果如下:只有映射正确才能在后面的
torch.Size([2, 57707]) # 边的总数
tensor([ 0, 1, 2, ..., 6133, 6133, 6133], device='cuda:0') # 映射后的原子序号标识
tensor([ 504, 504, 504, ..., 6130, 6131, 6132], device='cuda:0') # 映射后的原子序号标识
tensor([ 504, 456, 572, 465, 466, 1115, 506, 1043, 451, 547],
device='cuda:0') # 一个批次内每个蛋白质的序列长度
三、不用edge_index这个变量名表示边索引
1、一批次的数据不能一一对应,因为edge_index
的自动映射是使用了节点特征的,根据节点的数量进行映射。使用其他名字导致出现如下报错:
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 5062 but got size 4916 for tensor number 1 in the list.
2、即使通过一些手段(比如交换数据的维度)侥幸“解决”
这个报错,但是没有经过经过获取一批次内节点的总数进行数据映射,存在大量的重叠节点,在下游的操作中必将出现逻辑错误
,如下:显然未经过映射的原子范围很窄
torch.Size([2, 57707])
tensor([ 0, 1, 2, ..., 449, 450, 450], device='cuda:0')
tensor([1043, 1043, 1043, ..., 450, 448, 449], device='cuda:0')
tensor([1043, 465, 466, 456, 572, 1115, 506, 504, 547, 451],
四、总结
在使用边索引时,应当使用torch_geometric的DATA类
提供的edge_index
变量,否则在后期程序用到边索引时,程序必将出现错误,大家可以注意一下