基于矩阵形式的GNN模型实现
数据格式
法一
特别注意:这里使用交叉熵作为模型的损失函数,准确率作为评测指标。这里直接根据loss计算梯度并更新变量,没有像论文中那样先迭代到状态收敛的时刻t,再求t时刻的梯度。
import tensorflow as tf
import numpy as np
import pandas as pd
import scipy.io as sio
import os
def f_w(inp):
    """State network: map ``inp`` to a state vector of width ``state_dim``.

    Two dense layers built under the 'State_net' variable scope: a 5-unit
    sigmoid hidden layer followed by a ``state_dim``-unit sigmoid layer.
    ``state_dim`` is read from module scope.
    """
    with tf.variable_scope('State_net'):
        hidden = tf.layers.dense(inp, 5, activation=tf.nn.sigmoid)
        new_state = tf.layers.dense(hidden, state_dim,
                                    activation=tf.nn.sigmoid)
    return new_state
def g_w(inp):
    """Output network: map ``inp`` to ``output_dim`` un-activated outputs.

    Two dense layers built under the 'Output_net' variable scope: a 5-unit
    sigmoid hidden layer followed by an ``output_dim``-unit layer with no
    activation. ``output_dim`` is read from module scope.
    """
    with tf.variable_scope('Output_net'):
        hidden = tf.layers.dense(inp, 5, activation=tf.nn.sigmoid)
        output = tf.layers.dense(hidden, output_dim, activation=None)
    return output
def convergence(a, state, old_state, k):
    """Perform one state-update step and advance the iteration counter.

    The signature mirrors ``condition``; presumably the two are used as the
    body/cond pair of a ``tf.while_loop`` -- confirm against the caller.

    Args:
        a: 2-D tensor whose column 0 is used as a row index into the state
            (per the original notes, the child-node id); the remaining
            columns are fed to the state network as features.
        state: current state tensor.
        old_state: previous state; its incoming value is ignored and it is
            overwritten with ``state``.
        k: iteration counter.

    Returns:
        ``(a, new_state, old_state, k + 1)`` where ``old_state`` is the
        ``state`` passed in.
    """
    with tf.variable_scope('Convergence'):
        # The state entering this step becomes the "previous" state.
        old_state = state

        # Look up the previous state of the node referenced by column 0.
        child_ids = tf.cast(a[:, 0], tf.int32)
        child_states = tf.gather(old_state, child_ids)

        # Drop the id column and append the gathered state as the trailing
        # columns of the network input.
        features = a[:, 1:]
        net_input = tf.concat([features, child_states], axis=1)

        # Per-row contribution computed by the state network.
        contributions = f_w(net_input)

        # Aggregate the contributions through the sparse ArcNode matrix
        # (module-level) to obtain the new state.
        state = tf.sparse_tensor_dense_matmul(ArcNode, contributions)

        # Advance the iteration counter.
        k = k + 1
    return a, state, old_state, k
def condition(a, state, old_state, k):
    """Return a boolean tensor that is True while iteration should go on.

    Iteration continues when at least one row of ``state`` is farther than
    ``state_threshold`` (Euclidean distance) from the matching row of
    ``old_state`` AND ``k`` is still below ``max_iter``. Both limits are
    read from module scope. ``a`` is unused; it is accepted so the
    signature matches ``convergence``.
    """
    with tf.variable_scope('condition'):
        # Row-wise Euclidean distance between state(t) and state(t-1).
        # The 1e-10 term keeps the sqrt argument strictly positive.
        delta = tf.subtract(state, old_state)
        row_distance = tf.sqrt(tf.reduce_sum(tf.square(delta), 1) + 1e-10)

        # True if any row has not yet converged under the threshold.
        still_moving = tf.reduce_any(
            tf.greater(row_distance, state_threshold))

        # True while the iteration budget is not exhausted.
        under_limit = tf.less(k, max_iter)

        return tf.logical_and(still_moving, under_limit)
def set_load_general(data_path, set_type, set_name="sub_30_15"):
    """Load one split of a ``.mat`` dataset and build the GNN inputs.

    Args:
        data_path: directory that contains ``<set_name>.mat``.
        set_type: which split to read; one of "train", "validation", "test".
        set_name: base name of the ``.mat`` file, without extension.

    Returns:
        Tuple ``(inp, arcnode, nodegraph, nodein, labels, lab)``: the four
        values produced by ``GetInput``, the one-hot encoded targets, and
        the node-label matrix.

    Raises:
        NameError: if ``set_type`` is not one of the known splits.
    """
    types = ("train", "validation", "test")
    # Validate before touching the filesystem so a bad split name fails
    # fast instead of after an expensive (or failing) file load.
    if set_type not in types:
        raise NameError('Wrong set name!')

    # Function-scope imports: ``load`` is a project-local module, and
    # ``coo_matrix`` is imported here so the function does not depend on a
    # module-level name (the file header only imports ``scipy.io``).
    import load as ld
    from scipy.sparse import coo_matrix

    data = ld.loadmat(os.path.join(data_path, "{}.mat".format(set_name)))
    split = data["dataSet"]['{}Set'.format(set_type)]

    # Adjacency list: one (row, col) pair per non-zero entry of the
    # transposed connection matrix.
    adj = coo_matrix(split['connMatrix'].T)
    adj = np.array([adj.row, adj.col]).T

    # Node labels; force a 2-D (n, 1) layout when labels are 1-D.
    lab = np.asarray(split['nodeLabels']).T
    if len(lab.shape) < 2:
        lab = lab.reshape(lab.shape[0], 1)

    # Targets converted to one-hot encoding.
    target = np.asarray(split['targets']).T
    labels = pd.get_dummies(pd.Series(target)).values

    # Batch size 1 with every node assigned to graph 0.
    inp, arcnode, nodegraph, nodein = GetInput(
        adj, lab, 1, np.zeros(len(labels), dtype=int))
    return inp, arcnode, nodegraph, nodein, labels, lab
def GetInput(mat, lab, batch=1, grafi=None):
"""grafi is vector with same cardinaluty of nodes, denoting to which graph
belongs each node
"""
# numero di batch
batch_number = grafi.max() // batch # if only one graph => grafi.max() is 0 => batch_number == 0
# dataframe containing adjacency matrix
dmat = pd.DataFrame(mat, columns=["id_1", "id_2"])
# dataframe containing labels each node
dlab = pd.DataFrame(lab, columns=["lab" + str(i) for i in range(0, lab.shape[1])])
# darch=pd.DataFrame(arc, columns=["arch"+str(i) for i in range(0,arc.shape[1])])
# dataframe denoting graph belonging each node
dgr = pd.DataFrame(grafi, columns=["graph"])
# creating input : id_p, id_c, label_p, label_c, graph_belong
dresult = dmat
dresult = pd.merge(dresult, dlab, left_on="id_1", right_index=True, how='left')
dresult = pd.merge(dresult, dlab, left_on="id_2", right_index=True, how='left')
# dresult=pd.concat([dresult, darch], axis=1)
dresult = pd.merge(dresult, dgr, left_on="id_1", right_index=True, how='left')
data_batch = []
arcnode_batch = []
nodegraph_batch = []
node_in = []
# creating batch data => for each batch, redefining the id so that they start from 0 index
for i in range(0, batch_number + 1):
# getting minimum index of the current batch
grafo_indexMin = (i * batch)
grafo_indexMax = (i * batch) + batch
adj = dresult.loc[(dresult["graph"] >= grafo_indexMin) & (dresult[