Pytorch_geometric、networkx 从一组给定位置的点集构建图并生成Data类数据
从一组给定位置的点集构建图并生成Data类数据 PyG、networkx
- 本文的目的是根据已知位置的点数据构建图,即已知图节点,创建邻接矩阵生成边。并将生成的图转化为 PyG框架下可用的Data类型数据。
- 函数输入为包含点横纵坐标的二维list数组,输出为根据networkx创建的图G,以及图G在PyG下的Data类型数据。
邻接矩阵构造方法
这里参考了ER随机图生成方法,代码链接为ER随机图G(N,p)构造算法的python实现
- 设定一个距离阈值D(阈值的选取没有特殊要求,这里为了测试随便设定了一个数值)
- 计算当前节点到其他节点之间的距离d,
- 如果d<D,在这两个节点之间添加一条边,否则不添加
- 重复步骤2、3,遍历所有节点
步骤2 计算两点间的距离使用了jhsignal的函数python计算两点间的距离
Python代码实现
import networkx as nx
import matplotlib.pyplot as plt
import torch
import numpy as np
from torch_geometric.data import Data
import math
# 计算两点之间的距离
def cal_distance(p1, p2):
return math.sqrt(math.pow((p2[0] - p1[0]), 2) + math.pow((p2[1] - p1[1]), 2))
# 创建邻接矩阵并生成图G
def generateGrap():
G = nx.Graph() # 创建空的无向图
# 设定一个距离阈值
Distance = 2
# 初始化一个邻接矩阵
adjacentMatrix = np.zeros((num_nodes, num_nodes), dtype=int)
count = 0 # 给生成的边计数
# 生成一个邻接矩阵,生成边的条件是,从当前点计算与其他点之间的距离,如果距离在阈值内,则在两点之间生成一条边
for i in range(len(pos)):
for j in range(i + 1, len(pos)):
distance = cal_distance(pos[i], pos[j])
if distance < Distance:
count = count + 1
adjacentMatrix[i][j] = adjacentMatrix[j][i] = 1 # 无向图的邻接矩阵是对阵矩阵,对角线是0
print(count)
for i in range(len(adjacentMatrix)):
for j in range(len(adjacentMatrix)):
if adjacentMatrix[i][j] == 1: # 如果不加这句将生成完全图,邻接矩阵将不起作用
G.add_edge(i, j)
return G
# 转换为PyG中的Data格式数据
def star_data():
x = torch.eye(num_nodes) # 生成节点特征,节点个数的单位矩阵
g = generateGrap()
num_edges = g.number_of_edges()
edge_index1 = np.zeros((2, num_edges), dtype=int) #numpy格式
for i in range(2):
for j in range(num_edges):
edge_index1[i][j] = list(g.edges())[j][i]
data = Data(x, edge_index=torch.from_numpy(edge_index1)) # 转换成张量,再实例化Data类
return data
# 主程序开始
# pos代表每个星点的像素坐标,该数据从外部读入
pos = [[-1, 0], [0, 1], [1, 0]] # 二维list
num_nodes = len(pos) # 星点的个数(对应为图里的节点个数)
labels = dict(zip([i for i in range(num_nodes)], pos))
position = {k:v for k,v in labels.items()} # 绘图时节点固定的位置 数据类型为字典
G = generateGrap() # 从读入的星点坐标生成图G
print(G.edges()) # 打印图的所有边
testx = star_data() # 生成PyG的Data类型数据
print(testx)
# 画图
nx.draw_networkx(G, pos=position, with_labels=True, node_size=300)
plt.axis('off') #关闭图框
plt.show()
可视化结果
- 打印图的所有边:[(0, 1), (1, 2)]
- 打印Data:Data(x=[3, 3], edge_index=[2, 2])