从一组给定位置的点集构建图并生成Data类数据 PyG、networkx

Pytorch_geometric、networkx 从一组给定位置的点集构建图并生成Data类数据

从一组给定位置的点集构建图并生成Data类数据 PyG、networkx

  • 本文的目的是根据已知位置的点数据构建图,即已知图节点,创建邻接矩阵生成边。并将生成的图转化为 PyG框架下可用的Data类型数据。
  • 函数输入为包含点横纵坐标的二维list数组,输出为根据networkx创建的图G,以及图G在PyG下的Data类型数据。

邻接矩阵构造方法

这里参考了ER随机图生成方法,代码链接为ER随机图G(N,p)构造算法的python实现

  1. 设定一个距离阈值D(阈值的选取没有特殊要求,这里为了测试随便设定了一个数值)
  2. 计算当前节点到其他节点之间的距离d,
  3. 如果d<D,在这两个节点之间添加一条边,否则不添加
  4. 重复步骤2、3,遍历所有节点

步骤2 计算两点间的距离使用了jhsignal的函数python计算两点间的距离

Python代码实现

import networkx as nx
import matplotlib.pyplot as plt
import torch
import numpy as np
from torch_geometric.data import Data
import math

# 计算两点之间的距离
def cal_distance(p1, p2):
    return math.sqrt(math.pow((p2[0] - p1[0]), 2) + math.pow((p2[1] - p1[1]), 2))

# 创建邻接矩阵并生成图G
def generateGrap():
    G = nx.Graph()  # 创建空的无向图
    # 设定一个距离阈值
    Distance = 2
    # 初始化一个邻接矩阵
    adjacentMatrix = np.zeros((num_nodes, num_nodes), dtype=int)
    count = 0  # 给生成的边计数
    # 生成一个邻接矩阵,生成边的条件是,从当前点计算与其他点之间的距离,如果距离在阈值内,则在两点之间生成一条边
    for i in range(len(pos)):
        for j in range(i + 1, len(pos)):
            distance = cal_distance(pos[i], pos[j])
            if distance < Distance:
                count = count + 1
                adjacentMatrix[i][j] = adjacentMatrix[j][i] = 1  # 无向图的邻接矩阵是对阵矩阵,对角线是0
    print(count)

    for i in range(len(adjacentMatrix)):
        for j in range(len(adjacentMatrix)):
            if adjacentMatrix[i][j] == 1:  # 如果不加这句将生成完全图,邻接矩阵将不起作用
                G.add_edge(i, j)

    return G

# 转换为PyG中的Data格式数据
def star_data():
    x = torch.eye(num_nodes)  # 生成节点特征,节点个数的单位矩阵
    g = generateGrap()
    num_edges = g.number_of_edges()
    edge_index1 = np.zeros((2, num_edges), dtype=int)  #numpy格式
    for i in range(2):
        for j in range(num_edges):
            edge_index1[i][j] = list(g.edges())[j][i]

    data = Data(x, edge_index=torch.from_numpy(edge_index1)) # 转换成张量,再实例化Data类

    return data


# 主程序开始
# pos代表每个星点的像素坐标,该数据从外部读入
pos = [[-1, 0], [0, 1], [1, 0]]  # 二维list
num_nodes = len(pos)  # 星点的个数(对应为图里的节点个数)
labels = dict(zip([i for i in range(num_nodes)], pos))
position = {k:v for k,v in labels.items()}  # 绘图时节点固定的位置 数据类型为字典

G = generateGrap()   # 从读入的星点坐标生成图G
print(G.edges())     # 打印图的所有边
testx = star_data()  # 生成PyG的Data类型数据
print(testx)

# 画图
nx.draw_networkx(G, pos=position, with_labels=True, node_size=300)
plt.axis('off')  #关闭图框
plt.show()


可视化结果

  • 节点位置固定的无向图
  • 打印图的所有边:[(0, 1), (1, 2)]
  • 打印Data:Data(x=[3, 3], edge_index=[2, 2])
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值