import tensorflow as tf
import numpy as np
import gnn.gnn_utils as gnn_utils
data_path = "./data"
set_name = "sub_15_7_200"

# Training set. set_load_general returns a 6-tuple; the last element is unused
# here. Column 0 of every input array is dropped before use (presumably an
# identifier column rather than a feature — TODO confirm against gnn_utils).
inp, arcnode, nodegraph, nodein, labels, _ = gnn_utils.set_load_general(
    data_path, "train", set_name=set_name
)
inp = [features[:, 1:] for features in inp]

# Validation set, loaded and trimmed exactly like the training set.
inp_val, arcnode_val, nodegraph_val, nodein_val, labels_val, _ = gnn_utils.set_load_general(
    data_path, "validation", set_name=set_name
)
inp_val = [features[:, 1:] for features in inp_val]
# Model dimensions and iteration settings.
# Feature width of each input row, read from the first training array (after
# its first column was dropped). np.shape(...)[1] equals len(inp[0][0]) for a
# non-empty 2-D array, but unlike the indexing form it does not raise
# IndexError when the first array has zero rows.
input_dim = np.shape(inp[0])[1]
state_dim = 10  # presumably the per-node state vector size — TODO confirm
output_dim = 2  # presumably the number of output classes — TODO confirm
state_threshold = 0.001  # presumably the state-convergence tolerance — TODO confirm
max_iter = 50  # presumably the cap on state-update iterations — TODO confirm

# Build the model in TF1-style graph mode. Use the tf.compat.v1 namespace
# consistently: tf.reset_default_graph and tf.placeholder do not exist in the
# top-level namespace of TensorFlow 2.x, whereas tf.compat.v1 provides them on
# every TF version that already supports the disable_eager_execution call.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()
comp_inp = tf.compat.v1.placeholder(tf.float32, shape=(None, input_dim), name="input")