基于矩阵形式的GNN模型实现

基于矩阵形式的GNN模型实现

数据格式
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

法一

特别注意(这里使用交叉熵作为模型的损失函数,准确率作为评测结果的指标。这里直接根据loss计算梯度,对变量进行更新,没有像论文中那样计算梯度的稳定状态t,然后求t时刻的梯度

import tensorflow as tf
import numpy as np
import pandas as pd
import scipy.io as sio
import os

def f_w(inp):
    with tf.variable_scope('State_net'):
        layer1 = tf.layers.dense(inp, 5, activation=tf.nn.sigmoid)
        layer2 = tf.layers.dense(layer1, state_dim, activation=tf.nn.sigmoid)
        return layer2


def g_w(inp):
    with tf.variable_scope('Output_net'):
        layer1 = tf.layers.dense(inp, 5, activation=tf.nn.sigmoid)
        layer2 = tf.layers.dense(layer1, output_dim, activation=None)
        return layer2

def convergence(a, state, old_state, k):
    with tf.variable_scope('Convergence'):
        # assign current state to old state
        old_state = state

        # 获取子结点上一个时刻的状态
        # grub states of neighboring node
        gat = tf.gather(old_state, tf.cast(a[:, 0], tf.int32))

        # 去除第一列,即子结点的id
        # slice to consider only label of the node and that of it's neighbor
        # sl = tf.slice(a, [0, 1], [tf.shape(a)[0], tf.shape(a)[1] - 1])
        # equivalent code
        sl = a[:, 1:]

        # 将子结点上一个时刻的状态放到最后一列
        # concat with retrieved state
        inp = tf.concat([sl, gat], axis=1)

        # evaluate next state and multiply by the arch-node conversion matrix to obtain per-node states
        # 计算子结点对父结点状态的贡献
        layer1 = f_w(inp)
        # 聚合子结点对父结点状态的贡献,得到当前时刻的父结点的状态
        state = tf.sparse_tensor_dense_matmul(ArcNode, layer1)

        # update the iteration counter
        k = k + 1
    return a, state, old_state, k

def condition(a, state, old_state, k):
    # evaluate condition on the convergence of the state
    with tf.variable_scope('condition'):
        # 检查当前状态和上一个时刻的状态的欧式距离是否小于阈值
        # evaluate distance by state(t) and state(t-1)
        outDistance = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(state, old_state)), 1) + 1e-10)
        # vector showing item converged or not (given a certain threshold)
        checkDistanceVec = tf.greater(outDistance, state_threshold)

        c1 = tf.reduce_any(checkDistanceVec)

        # 是否达到最大迭代次数
        c2 = tf.less(k, max_iter)
    return tf.logical_and(c1, c2)

def set_load_general(data_path, set_type, set_name="sub_30_15"):
    import load as ld
    # load adjacency list
    types = ["train", "validation", "test"]
    train = ld.loadmat(os.path.join(data_path, "{}.mat".format(set_name)))
    train = train["dataSet"]
    try:
        if set_type not in types:
            raise NameError('Wrong set name!')

        # load adjacency list
        # take adjacency list
        adj = coo_matrix(train['{}Set'.format(set_type)]['connMatrix'].T)
        adj = np.array([adj.row, adj.col]).T

        # take node labels
        lab = np.asarray(train['{}Set'.format(set_type)]['nodeLabels']).T

        # if clique (labels with only one dimension
        if len(lab.shape) < 2:
            lab = lab.reshape(lab.shape[0], 1)

        # take targets and convert to one-hot encoding
        target = np.asarray(train['{}Set'.format(set_type)]['targets']).T
        labels = pd.get_dummies(pd.Series(target))
        labels = labels.values

        # compute inputs and arcnode

        inp, arcnode, nodegraph, nodein = GetInput(adj, lab, 1,
                                                           np.zeros(len(labels), dtype=int))
        return inp, arcnode, nodegraph, nodein, labels, lab

def GetInput(mat, lab, batch=1, grafi=None):
    """grafi is vector with same cardinaluty of nodes, denoting to which graph
        belongs each node
    """
    # numero di batch
    batch_number = grafi.max() // batch   # if only one graph => grafi.max() is 0 => batch_number == 0
    # dataframe containing adjacency matrix
    dmat = pd.DataFrame(mat, columns=["id_1", "id_2"])
    # dataframe containing labels each node
    dlab = pd.DataFrame(lab, columns=["lab" + str(i) for i in range(0, lab.shape[1])])
    # darch=pd.DataFrame(arc, columns=["arch"+str(i) for i in range(0,arc.shape[1])])
    # dataframe denoting graph belonging each node
    dgr = pd.DataFrame(grafi, columns=["graph"])

    # creating input : id_p, id_c, label_p, label_c, graph_belong
    dresult = dmat
    dresult = pd.merge(dresult, dlab, left_on="id_1", right_index=True, how='left')
    dresult = pd.merge(dresult, dlab, left_on="id_2", right_index=True, how='left')
    # dresult=pd.concat([dresult, darch], axis=1)
    dresult = pd.merge(dresult, dgr, left_on="id_1", right_index=True, how='left')

    data_batch = []
    arcnode_batch = []
    nodegraph_batch = []
    node_in = []
    # creating batch data => for each batch, redefining the id so that they start from 0 index
    for i in range(0, batch_number + 1):

        # getting minimum index of the current batch
        grafo_indexMin = (i * batch)
        grafo_indexMax = (i * batch) + batch

        adj = dresult.loc[(dresult["graph"] >= grafo_indexMin) & (dresult[
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值