#参数形式:
#row=[0,0,1,1,1,1,2,2,3,3,4]
#col=[1,2,3,4,0,1,2,3,2,3,1]
#pos=[[1,2,3],[1,2,3],[1,2,3],[1,2,3]]
#这上边的三个参数都是tensor类型的,只是我为了方便展示,就直接放在了[]里面
#之所以都是tensor类型的是因为,整个过程我都是在GPU上算的,GPU计算速度比CPU快多了,而且也可以做到在训练的时候不用切换CPU和GPU,
#除此以外,这个代码里面没有对边关系进行遍历,只是单纯的对结点进行了遍历,可以大大缩短运行时间
def fiting_edge_index(row,col,pos):
# time_start = time.time()
# constraint = (torch.cdist(pos, pos) < r).nonzero()
boundary_flag
torch实现,根据距离删除不符合要求的边关系(图是邻接表表示)
最新推荐文章于 2022-10-17 16:13:06 发布