# python实现造100组图结构数据(造SimGNN数据集) — generate 100 random graph pairs for a SimGNN dataset
import random
import numpy as np
import torch
from rdkit import Chem
import networkx as nx
import os
import json
def Get_GED(arr1, arr2):
    """Return the exact graph edit distance between two undirected graphs.

    Each argument is an edge list: an iterable of ``(u, v)`` pairs of
    non-negative integer node ids (for example the ``(E, 2)`` array returned
    by ``new_graph``). For each graph, nodes ``0 .. max_id`` are created,
    where ``max_id`` is the largest id appearing in any edge. An empty edge
    list therefore yields an empty graph.

    No ``node_match``/``edge_match`` is passed to networkx, so node and edge
    attributes play no role: every node or edge insertion/deletion costs 1.
    """
    def _build(edges):
        # Normalise endpoints to plain ints (rows may be numpy arrays).
        pairs = [(int(u), int(v)) for u, v in edges]
        graph = nx.Graph()
        # Highest node id over all endpoints; -1 (-> no nodes) when empty.
        max_id = max((max(u, v) for u, v in pairs), default=-1)
        graph.add_nodes_from(range(max_id + 1))
        graph.add_edges_from(pairs)
        return graph

    # Use an undirected Graph: a DiGraph holding both (u, v) and (v, u)
    # counts every undirected edge twice, doubling all edge edit costs.
    return nx.graph_edit_distance(_build(arr1), _build(arr2))
def new_graph(node_num, edge_prob=0.5):
    """Generate a random undirected graph in which every node has an edge.

    Each unordered pair of distinct nodes ``{i, j}`` is connected
    independently with probability ``edge_prob``. Afterwards, any node that
    is still isolated is joined to a uniformly random *other* node, so every
    node ends up with degree >= 1 whenever ``node_num >= 2``.

    Args:
        node_num: number of nodes; ids are ``0 .. node_num - 1``.
        edge_prob: probability of keeping each candidate edge (default 0.5).

    Returns:
        ``(node_du, edges)``: ``node_du`` is an int array of length
        ``node_num`` holding each node's degree, and ``edges`` is an int
        array of shape ``(E, 2)`` listing each undirected edge once.
        With ``node_num < 2`` no edge is possible and ``edges`` is empty.
    """
    node_du = np.zeros(node_num, dtype=int)
    edges = []

    # Visit each unordered pair exactly once (j > i) so it is kept with
    # probability edge_prob; re-rolling (j, i) after a failed (i, j) draw
    # would raise that to 1 - (1 - edge_prob) ** 2.
    for i in range(node_num):
        for j in range(i + 1, node_num):
            if np.random.uniform(0, 1) <= edge_prob:
                edges.append((i, j))
                node_du[i] += 1
                node_du[j] += 1

    # Attach every isolated node to a random node other than itself. The
    # partner loop must start at `i` so it draws at least once. A lone node
    # (node_num == 1) has no possible partner, so skip instead of looping
    # forever.
    if node_num >= 2:
        for i in range(node_num):
            if node_du[i] == 0:
                partner = i
                while partner == i:
                    partner = random.randint(0, node_num - 1)
                edges.append((i, partner))
                node_du[i] += 1
                node_du[partner] += 1

    # dtype=int keeps the empty case integer-typed with shape (0, 2).
    return node_du, np.array(edges, dtype=int).reshape(-1, 2)
def new_ged():
    """Build one random graph pair and compute its graph edit distance.

    Two graphs are drawn one after the other, each with a node count drawn
    by ``np.random.randint(3, 8)`` (3 to 7 nodes, upper bound exclusive).

    Returns:
        ``(ged, degrees_1, degrees_2, edges_1, edges_2)``, where the degree
        and edge arrays are exactly what ``new_graph`` returned for each
        graph and ``ged`` is ``Get_GED(edges_1, edges_2)``.
    """
    generated = []
    for _ in range(2):
        # Size draw and graph build alternate: size 1, graph 1, size 2, graph 2.
        size = np.random.randint(3, 8)
        generated.append(new_graph(size))
    (degrees_1, edges_1), (degrees_2, edges_2) = generated
    distance = Get_GED(edges_1, edges_2)
    return distance, degrees_1, degrees_2, edges_1, edges_2
# Script body: generate 100 random graph pairs and save each one as a
# SimGNN-style JSON file ../dataset/newtrain/<i>.json with keys
# graph_1, ged, graph_2, labels_2, labels_1 (labels are node degrees as
# strings).
#
# Serialize with json.dumps rather than patching up str(dict). The numpy
# array repr embeds a platform-dependent dtype (e.g. '<U11' vs '<U21') and
# wraps long rows, so string replacement can emit invalid JSON. Writing the
# .json file directly also makes a separate .txt -> .json rename pass
# unnecessary.
output_dir = '../dataset/newtrain/'
os.makedirs(output_dir, exist_ok=True)
for i in range(0, 100):
    ged, du1, du2, e1, e2 = new_ged()
    record = {
        "graph_1": e1.tolist(),
        "ged": float(ged),
        "graph_2": e2.tolist(),
        "labels_2": [str(d) for d in du2],
        "labels_1": [str(d) for d in du1],
    }
    ans = json.dumps(record)
    print(ans)
    with open(os.path.join(output_dir, f"{i}.json"), 'w', encoding='utf-8') as f:
        print(ans, file=f)