Explainability Research (Part 2) - XGNN

Core Ideas of the Paper

Goal

Here the authors target the GNN graph-classification problem and study a model-level explanation method; concretely, they train a graph generator.

Let $f(\cdot)$ denote a trained GNN model and $y \in \{c_1, \cdots, c_\ell\}$ the graph labels. Given the trained GNN $f(\cdot)$ and a label $c_i$, the graph generator generates a graph $G^*$ that is predicted as $c_i$, defined as

$$G^* = \mathop{argmax}\limits_{G} P(f(G) = c_i)$$

That is, we maximize the probability that $G^*$ is predicted as $c_i$.

In the figure below, all four graphs are predicted as class 3, and a human observer can see that a triangle is a structure shared by all four. The generator's ultimate goal is to produce similar graphs, with Graph rules (akin to manual validation) introduced to enhance validity.
(figure)

Graph Generator Objective

Denote the graph generator by $g_\theta(\cdot)$. The authors generate $G^*$ over $T$ steps; the graph generated at step $t$ is $G_t$, which consists of

  • $n_t$ nodes
  • a feature matrix $X_t \in \mathbb{R}^{n_t \times d}$
  • an adjacency matrix $A_t \in \{0,1\}^{n_t \times n_t}$


$$X_{t+1}, A_{t+1} = g_\theta(X_t, A_t)$$

The generation task is cast as reinforcement learning. Suppose the dataset contains $k$ node types; define the candidate set $C = \{s_1, s_2, \cdots, s_k\}$. In chemical molecular graphs, for example, node types are atom types, so $C = \{\text{carbon}, \text{hydrogen}, \cdots, \text{oxygen}\}$. Social-network nodes are untyped, so there the candidate set has only one type.

$g_\theta(\cdot)$ learns to add an edge to $G_t$ to obtain $G_{t+1}$: either an edge between two existing nodes of $G_t$, or a new node drawn from the candidate set together with the edge connecting it.

A reinforcement learning task typically has four components: state, action, policy, and reward.

  • state: the state at step $t$ is the graph $G_t$. The initial graph can be a single node chosen at random from the candidate set, or chosen manually; for generating organic molecules, a carbon atom is often used as the initial graph.

  • action: the action at step $t$, denoted $a_t$, is the process of producing $G_{t+1}$ from $G_t$. Concretely, it selects a start node and an end node and adds an edge between them. The start node $a_{t,start}$ must be a node of $G_t$, while the end node $a_{t,end}$ can be a node of either $G_t$ or $C$.

  • policy: the policy is the graph generator $g_\theta(\cdot)$ itself, trained via the reward mechanism and policy gradients.

  • reward: the reward at step $t$, denoted $R_t$, consists of two parts:

    • Guidance from the pre-trained GNN $f(\cdot)$: it encourages the generated graph to be classified as $c_i$, and the predicted probability is used as feedback to update $g_\theta(\cdot)$.
    • Encouragement for the generated graphs to be valid under graph rules, e.g. two nodes in a social network cannot be connected by multiple edges, and an atom's degree in a molecular graph cannot exceed its valency.

    The reward includes both intermediate rewards and an overall reward.

The Graph Generator

At step $t$, the action $a_t$ is written $(a_{t,start}, a_{t,end})$. The goal of $g_\theta(\cdot)$ is to predict the action probabilities $p_t = (p_{t,start}, p_{t,end})$ from $G_t$ and $C$. $g_\theta(\cdot)$ consists of several GCN layers.

The process can be described as

$$\widehat{X} = \mathrm{GCNs}(G_t, C)$$

$$p_{t,start} = \mathrm{Softmax}(\mathrm{MLPs}(\widehat{X}))$$

$$p_{t,end} = \mathrm{Softmax}(\mathrm{MLPs}([\widehat{X}, \hat{x}_{start}]))$$

where

  • $\widehat{X}$: the node features learned by the GCNs
  • $a_{t,start} \sim p_{t,start} \odot m_{t,start}$, where $m_{t,start}$ is a mask vector that filters out the candidate-set nodes; $a_{t,start}$ takes the highest-probability node in the masked $p_{t,start}$ (a toy sketch follows this list)
  • $\hat{x}_{start}$: the feature vector of $a_{t,start}$
  • $a_{t,end} \sim p_{t,end} \odot m_{t,end}$, where $m_{t,end}$ is a mask vector that filters out the node $a_{t,start}$
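
To make the masking concrete, here is a toy sketch of selecting $a_{t,start}$; the sizes (4 graph nodes plus 3 candidate nodes) and the score tensor are hypothetical, not from the paper:

import torch
import torch.nn.functional as F

scores = torch.randn(7)                               # MLP scores: 4 nodes of G_t + 3 candidate nodes
p_start = F.softmax(scores, dim=0)                    # p_{t,start}
m_start = torch.tensor([1., 1., 1., 1., 0., 0., 0.])  # m_{t,start}: mask out the candidate set
a_start = torch.argmax(p_start * m_start)             # a_{t,start}: best unmasked node of G_t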

An example is shown in the figure below, where Current Graph is $G_t$. $G_t$ contains 4 nodes and the candidate set contains 3 node types. The generation process:

  • Concatenate $G_t$'s feature matrix $X_t$ with the feature vectors of the nodes in $C$ to form the feature matrix $X$, and expand $G_t$'s adjacency matrix $A_t$ into $A$ (from $\mathbb{R}^{4 \times 4}$ to $\mathbb{R}^{7 \times 7}$)
  • Pass through the GCNs to obtain the node features $\widehat{X}$ (the cyan matrix)
  • $\widehat{X}$ goes through the first MLPs to predict the start node $a_{t,start}$ of the new edge; nodes marked $\times$ are masked, and all nodes of $C$ are masked here
  • $[\widehat{X}, \hat{x}_{start}]$ goes through the second MLPs to predict the end node $a_{t,end}$ of the new edge; here the start node is masked
  • Form $G_{t+1}$, which has one more node and one more edge than $G_t$
    (figure)

Training the Graph Generator

Training $g_\theta(\cdot)$ uses policy gradient. The loss is
$$\mathcal{L}_g = -R_t\left(\mathcal{L}_{CE}(p_{t,start}, a_{t,start}) + \mathcal{L}_{CE}(p_{t,end}, a_{t,end})\right)$$
where

  • $\mathcal{L}_{CE}$ is the cross-entropy loss
  • $R_t$ is the reward function at step $t$

$R_t$ consists of two parts, $R_{t,f}$ and $R_{t,r}$.

$$R_{t,f}(G_{t+1}) = p(f(G_{t+1}) = c_i) - \frac{1}{\ell}$$

$$R_{t,f} = R_{t,f}(G_{t+1}) + \lambda_1 \cdot \frac{\sum_{i=1}^{m} R_{t,f}(\mathrm{Rollout}(G_{t+1}))}{m}$$

$$R_t = R_{t,f}(G_{t+1}) + \lambda_1 \cdot \frac{\sum_{i=1}^{m} R_{t,f}(\mathrm{Rollout}(G_{t+1}))}{m} + \lambda_2 \cdot R_{t,r}$$

where

  • $\ell$ is the number of graph classes
  • $\lambda_1$ and $\lambda_2$ are hyperparameters
  • $R_{t,r}$ encodes hand-crafted graph rules, e.g. every node of a molecular graph must satisfy its valency constraints (the molecule must be chemically valid); otherwise $R_{t,r}$ is negative. A minimal sketch of the full reward follows this list.
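
A minimal sketch of the reward computation implied by these formulas; the function and argument names are hypothetical, and rollout_probs stands for the class-$c_i$ probabilities of the $m$ rolled-out graphs:

def reward(p_ci, rollout_probs, r_rule, l=2, lam1=1.0, lam2=2.0):
    # R_{t,f}(G_{t+1}) = p(f(G_{t+1}) = c_i) - 1/l
    r_tf = p_ci - 1.0 / l
    # lambda_1 * average feedback over the m rollouts
    r_roll = lam1 * sum(p - 1.0 / l for p in rollout_probs) / len(rollout_probs)
    # R_t = immediate feedback + rollout feedback + lambda_2 * graph-rule term
    return r_tf + r_roll + lam2 * r_rule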

(figure: the XGNN training algorithm) The most important part of this algorithm is lines 8, 9, and 10.

Experiments

Code for reproducing the experiments: XGNN-impl

Dataset

The paper uses the synthetic dataset Is_Acyclic and the real-world dataset MUTAG; here I reproduce the results on MUTAG.

Graphs in MUTAG are divided into 2 classes according to their mutagenic effect on a bacterium. The node types are Carbon, Nitrogen, Oxygen, Fluorine, Iodine, Chlorine, and Bromine; edge types are not used here.

MUTAG contains 188 molecular graphs, with 3371 nodes (atoms) and 7442 edges (chemical bonds) in total. The dataset directory looks like this:

(figure: dataset directory)

  • node_labels.txt records the type of each of the 3371 nodes (numbered 0-6)
  • graph_indicator.txt records the graph index each node belongs to (graph indices numbered 1-188)
  • graph_labels.txt records the label of each of the 188 graphs (1 or -1)
  • A.txt records the 7442 edges as (start_node_idx, end_node_idx) pairs; both start_node_idx and end_node_idx fall within the 3371-node range
  • edge_labels.txt records the type of each of the 7442 edges (not used here)

The dataset-loading code:

import numpy as np
import scipy.sparse as sp
import torch

def encode_onehot(labels):
    classes = set(labels)
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
                    enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)),
                             dtype=np.int32)
    return labels_onehot


def normalize(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx

def load_split_MUTAG_data(path="datas/MUTAG/", dataset="MUTAG_", split_train=0.7, split_val=0.15):
    """Load MUTAG data """
    print('Loading {} dataset...'.format(dataset))

    # Load the graph labels
    graph_labels = np.genfromtxt("{}{}graph_labels.txt".format(path, dataset),
                           dtype=np.dtype(int))
    graph_labels = encode_onehot(graph_labels)  # (188, 2)
    graph_labels = torch.LongTensor(np.where(graph_labels)[1]) # (188, 1)


    # Graph membership of each node
    graph_idx = np.genfromtxt("{}{}graph_indicator.txt".format(path, dataset),
                              dtype=np.dtype(int))

    graph_idx = np.array(graph_idx, dtype=np.int32)
    idx_map = {j: i for i, j in enumerate(graph_idx)} # idx_map[j]: index of the last node of graph j (later nodes overwrite earlier entries)
    length = len(idx_map.keys()) # total number of graphs
    num_nodes = [idx_map[n] - idx_map[n - 1] if n > 1 else idx_map[n] + 1 for n in range(1, length + 1)] # list of length 188: node count of each graph
    max_num_nodes = max(num_nodes) # node count of the largest graph
    features_list = []
    adj_list = []
    prev = 0

    # Node labels, turned into one-hot features
    nodeidx_features = np.genfromtxt("{}{}node_labels.txt".format(path, dataset), delimiter=",",
                                     dtype=np.dtype(int))
    node_features = np.zeros((nodeidx_features.shape[0], max(nodeidx_features) + 1))
    node_features[np.arange(nodeidx_features.shape[0]), nodeidx_features] = 1

    # Edge list
    edges_unordered = np.genfromtxt("{}{}A.txt".format(path, dataset), delimiter=",",
                                    dtype=np.int32)

    # Edge labels
    edges_label = np.genfromtxt("{}{}edge_labels.txt".format(path, dataset), delimiter=",",
                                dtype=np.int32)  # shape = (7442,)

    # Build the adjacency matrix A covering all edges in the dataset (edge labels serve as entry values)
    adj = sp.coo_matrix((edges_label, (edges_unordered[:, 0] - 1, edges_unordered[:, 1] - 1)))

    # Symmetrize the adjacency matrix
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)

    node_features = normalize(node_features)
    adj = normalize(adj + sp.eye(adj.shape[0])) # A~ = A + I_N, then row-normalize
    adj = adj.todense()

    for n in range(1, length + 1):
        end = idx_map[n] + 1  # one past the last node of graph n

        # entry is the feature matrix X of graph n, zero-padded to max_num_nodes
        entry = np.zeros((max_num_nodes, max(nodeidx_features) + 1))
        entry[:end - prev] = node_features[prev:end]
        entry = torch.FloatTensor(entry)
        features_list.append(entry.tolist())

        # entry is the adjacency matrix A of graph n, zero-padded to max_num_nodes
        entry = np.zeros((max_num_nodes, max_num_nodes))
        entry[:end - prev, :end - prev] = adj[prev:end, prev:end]
        entry = torch.FloatTensor(entry)
        adj_list.append(entry.tolist())

        prev = end  # start index of the next graph's nodes

    num_total = max(graph_idx)
    num_train = int(split_train * num_total)
    num_val = int((split_train + split_val) * num_total)

    if (num_train == num_val or num_val == num_total):
        return

    features_list = torch.FloatTensor(features_list)
    adj_list = torch.FloatTensor(adj_list)

    idx_train = range(num_train)
    idx_val = range(num_train, num_val)
    idx_test = range(num_val, num_total)

    idx_train = torch.LongTensor(idx_train)
    idx_val = torch.LongTensor(idx_val)
    idx_test = torch.LongTensor(idx_test)

    # Return values, in order: adjacency matrices of the 188 graphs, feature matrices of the 188 graphs,
    # labels of the 188 graphs, last-node index of each graph, and the train/val/test index sets
    return adj_list, features_list, graph_labels, idx_map, idx_train, idx_val, idx_test

Each of the 188 graphs thus gets an adjacency matrix of shape $max\_node\_num \times max\_node\_num$ and a feature matrix of shape $max\_node\_num \times feature\_dim$.
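
The training script below also imports an accuracy helper from Load_dataset that is not shown above; a minimal sketch consistent with how it is called (batched class scores in, scalar accuracy out):

def accuracy(output, labels):
    # output: [num_graphs, num_classes] scores; labels: [num_graphs] class ids
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double().sum()
    return correct / len(labels)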

Training the GCN Classifier

Here $f(\cdot)$ is implemented as a GCN; the model code is as follows.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    paper: Semi-Supervised Classification with Graph Convolutional Networks
    """
    # Model parameters: weight and bias
    def __init__(self, in_features, out_features):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    # Weight initialization
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        self.bias.data.uniform_(-stdv, stdv)

    # Analogous to toString
    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

    # Computes A~ X W
    def forward(self, input, adj):
        # input.shape = [max_node, features] = X
        # adj.shape = [max_node, max_node] = A~
        # torch.mm(a, b) is matrix multiplication; torch.mul(a, b) is element-wise multiplication (shapes must match)
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        return output + self.bias


class GCN(nn.Module):
    # number of input features; number of output classes

    def __init__(self, nfeat, nclass, dropout):
        """ As per paper """
        """ 3 layers of GCNs with output dimensions equal to 32, 48, 64 respectively and average all node features """
        """ Final classifier with 2 fully connected layers and hidden dimension set to 32 """
        """ Activation function - ReLu (Mutag) """
        super(GCN, self).__init__()

        self.dropout = dropout

        self.gc1 = GraphConvolution(nfeat, 32)
        self.gc2 = GraphConvolution(32, 48)
        self.gc3 = GraphConvolution(48, 64)
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, nclass)

    def forward(self, x, adj):
        # x.shape = [max_node, features]
        # adj.shape = [max_node, max_node]
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu(self.gc2(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu(self.gc3(x, adj))


        y = torch.mean(x, 0)  # mean-pool all node features as the graph-level readout
        y = F.relu(self.fc1(y))
        y = F.dropout(y, self.dropout, training=self.training)
        y = F.softmax(self.fc2(y), dim=0)  # returns class probabilities (the generator's reward code indexes these directly)

        return y
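
As a quick sanity check, the model can be run on one random padded graph; the shapes below are hypothetical stand-ins for the loader's output:

import torch

model = GCN(nfeat=7, nclass=2, dropout=0.1)
x = torch.rand(28, 7)  # feature matrix X of one padded graph (28 = MUTAG's max node count)
adj = torch.eye(28)    # stand-in for the normalized adjacency A~
print(model(x, adj))   # two class probabilities summing to 1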

The training script for the GCN classifier:

from Load_dataset import load_split_MUTAG_data, accuracy
from Model import GCN
import time

import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F

model_path = 'model/gcn_first.pth'

epochs = 1000
seed = 200
lr = 0.001
dropout = 0.1
weight_decay = 5e-4

np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


class EarlyStopping():
    def __init__(self, patience=10, min_loss=0.5, hit_min_before_stopping=False):
        self.patience = patience
        self.counter = 0
        self.hit_min_before_stopping = hit_min_before_stopping
        if hit_min_before_stopping:
            self.min_loss = min_loss
        self.best_loss = None
        self.early_stop = False

    def __call__(self, loss):
        if self.best_loss is None:
            self.best_loss = loss
        elif loss > self.best_loss:
            self.counter += 1
            if self.counter > self.patience:
                if self.hit_min_before_stopping == True and loss > self.min_loss:
                    print("Cannot hit mean loss, will continue")
                    self.counter -= self.patience
                else:
                    self.early_stop = True
        else:
            self.best_loss = loss
            self.counter = 0  # reset the patience counter on improvement


if __name__ == '__main__':
    # adj_list: [188, max_num_nodes, max_num_nodes]
    # features_list: [188, max_num_nodes, 7]
    # graph_labels: [188]
    adj_list, features_list, graph_labels, idx_map, idx_train, idx_val, idx_test = load_split_MUTAG_data()
    idx_train = torch.cat([idx_train, idx_val, idx_test])

    model = GCN(nfeat=features_list[0].shape[1], # nfeat = 7
                nclass=graph_labels.max().item() + 1, # nclass = 2
                dropout=dropout)
    optimizer = optim.Adam(model.parameters(),
                           lr=lr, weight_decay=weight_decay)

    model.cuda()
    features_list = features_list.cuda()
    adj_list = adj_list.cuda()
    graph_labels = graph_labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()

    # Train the model
    early_stopping = EarlyStopping(10, hit_min_before_stopping=True)
    t_total = time.time()

    for epoch in range(epochs):
        t = time.time()
        model.train()
        optimizer.zero_grad()

        # Forward pass, one graph at a time (the model takes a single graph per call)
        outputs = []
        for i in idx_train:
            output = model(features_list[i], adj_list[i])
            output = output.unsqueeze(0)
            outputs.append(output)
        output = torch.cat(outputs, dim=0)


        loss_train = F.cross_entropy(output, graph_labels[idx_train])
        acc_train = accuracy(output, graph_labels[idx_train])
        loss_train.backward()
        optimizer.step()

        model.eval()
        outputs = []
        for i in idx_val:
            output = model(features_list[i], adj_list[i])
            output = output.unsqueeze(0)
            outputs.append(output)
        output = torch.cat(outputs, dim=0)
        loss_val = F.cross_entropy(output, graph_labels[idx_val])
        acc_val = accuracy(output, graph_labels[idx_val])

        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(loss_train.item()),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()),
              'time: {:.4f}s'.format(time.time() - t))

        early_stopping(loss_val)
        if early_stopping.early_stop == True:
            break

    print("Optimization Finished!")
    print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

    torch.save(model.state_dict(), model_path)
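
The saved checkpoint is what the Generator in the next section loads internally; reloading it standalone looks like this (a usage sketch):

from Model import GCN
import torch

model = GCN(nfeat=7, nclass=2, dropout=0.1)
model.load_state_dict(torch.load('model/gcn_first.pth'))
model.eval()  # disable dropout for inference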

Training the Graph Generator

The Generator class definition:

import random
import copy
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from Model import GraphConvolution, GCN

rollout = 10
max_gen_step = 10
MAX_NUM_NODES = 28 # for mutag
random.seed(200)

class Generator(nn.Module):
    def __init__(self, model_path: str, C: list, node_feature_dim: int, num_class=2, c=0, hyp1=1, hyp2=2, start=None, nfeat=7, dropout=0.1):
        """
        :param C: Candidate set of nodes (list)
        :param start: Starting node (defaults to randomised node)
        """
        super(Generator, self).__init__()
        self.nfeat = nfeat
        self.dropout = dropout
        self.c = c

        self.fc = nn.Linear(nfeat, 8)
        self.gc1 = GraphConvolution(8, 16)
        self.gc2 = GraphConvolution(16, 24)
        self.gc3 = GraphConvolution(24, 32)

        # MLP1
        # 2 FC layers with hidden dimension 16
        self.mlp1 = nn.Sequential(nn.Linear(32, 16), nn.Linear(16, 1))

        # MLP2
        # 2 FC layers with hidden dimension 24
        self.mlp2 = nn.Sequential(nn.Linear(64, 24), nn.Linear(24, 1))

        # Hyperparameters
        self.hyp1 = hyp1
        self.hyp2 = hyp2
        self.candidate_set = C

        # Default starting node (if any)
        if start is not None:
            self.start = start
            self.random_start = False
        else:
            self.start = random.choice(np.arange(0, len(self.candidate_set)))
            self.random_start = True

        # Load GCN for calculating reward
        self.model = GCN(nfeat=node_feature_dim,
                         nclass=num_class,
                         dropout=dropout)

        self.model.load_state_dict(torch.load(model_path))
        for param in self.model.parameters():
            param.requires_grad = False

        self.reset_graph()

    def reset_graph(self):
        """
        Reset g.G to the default graph containing only the start node (a one-node graph)
        """
        if self.random_start == True:
            self.start = random.choice(np.arange(0, len(self.candidate_set)))

        # In the initial graph, every node except the first is masked; since the adjacency matrix has side length MAX_NUM_NODES + len(self.candidate_set), the mask covers not only the candidate-set nodes but also all placeholder (padding) nodes
        mask_start = torch.BoolTensor(
            [False if i == 0 else True for i in range(MAX_NUM_NODES + len(self.candidate_set))])

        adj = torch.zeros((MAX_NUM_NODES + len(self.candidate_set), MAX_NUM_NODES + len(self.candidate_set)),
                          dtype=torch.float32)   # adj has shape [MAX_NUM_NODES + len(C), MAX_NUM_NODES + len(C)]; it may contain empty placeholder nodes

        feat = torch.zeros((MAX_NUM_NODES + len(self.candidate_set), len(self.candidate_set)), dtype=torch.float32)
        feat[0, self.start] = 1
        feat[np.arange(-len(self.candidate_set), 0), np.arange(0, len(self.candidate_set))] = 1

        degrees = torch.zeros(MAX_NUM_NODES)

        self.G = {'adj': adj, 'feat': feat, 'degrees': degrees, 'num_nodes': 1, 'mask_start': mask_start}

    ## Compute G_t -> G_{t+1}
    def forward(self, G_in):
        ## G_in is G_t
        G = copy.deepcopy(G_in)

        x = G['feat'].detach().clone() # feature matrix of G_t
        adj = G['adj'].detach().clone() # adjacency matrix of G_t

        ## Corresponds to X^ = GCNs(G_t, C)
        x = F.relu6(self.fc(x))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu6(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu6(self.gc2(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu6(self.gc3(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)

        ## p_{t,start} = Softmax(MLPs(X^))
        p_start = self.mlp1(x)
        p_start = p_start.masked_fill(G['mask_start'].unsqueeze(1), 0)
        p_start = F.softmax(p_start, dim=0)
        a_start_idx = torch.argmax(p_start.masked_fill(G['mask_start'].unsqueeze(1), -1))

        ## p_{t,end} = Softmax(MLPs([X^, x^_start]))
        # broadcast
        x1, x2 = torch.broadcast_tensors(x, x[a_start_idx])
        x = torch.cat((x1, x2), 1)  # cat increases dim from 32 to 64

        # Compute m_{t,end}: mask everything except the candidate set and the nodes of G_t not chosen as the start node
        mask_end = torch.BoolTensor([True for i in range(MAX_NUM_NODES + len(self.candidate_set))])
        mask_end[MAX_NUM_NODES:] = False
        mask_end[:G['num_nodes']] = False
        mask_end[a_start_idx] = True

        p_end = self.mlp2(x)
        p_end = p_end.masked_fill(mask_end.unsqueeze(1), 0)
        p_end = F.softmax(p_end, dim=0)
        a_end_idx = torch.argmax(p_end.masked_fill(mask_end.unsqueeze(1), -1))

        # Return new G
        # If a_end_idx is not masked, node exists in graph, no new node added
        if G['mask_start'][a_end_idx] == False:
            G['adj'][a_end_idx][a_start_idx] += 1
            G['adj'][a_start_idx][a_end_idx] += 1

            # Update degrees of both (existing) endpoints
            G['degrees'][a_start_idx] += 1
            G['degrees'][a_end_idx] += 1
        else:
            # Add node
            G['feat'][G['num_nodes']] = G['feat'][a_end_idx]
            # Add edge
            G['adj'][G['num_nodes']][a_start_idx] += 1
            G['adj'][a_start_idx][G['num_nodes']] += 1
            # Update degrees
            G['degrees'][a_start_idx] += 1
            G['degrees'][G['num_nodes']] += 1

            # Update start mask
            G_mask_start_copy = G['mask_start'].detach().clone()
            G_mask_start_copy[G['num_nodes']] = False
            G['mask_start'] = G_mask_start_copy

            G['num_nodes'] += 1

        return p_start, a_start_idx, p_end, a_end_idx, G

The forward function computes $G_{t+1}$ from $G_t$. The maximum number of nodes a graph may have in this classification task is MAX_NUM_NODES = 28, and the candidate set $C$ has 7 node types. The adjacency matrices of both $G_t$ and $G_{t+1}$ therefore have side length MAX_NUM_NODES + len(candidate set) = 35; the extra entries are placeholder nodes (similar to padding), which the masks must account for.
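
To make the padding concrete, here is a small shape check of the initial graph state; this assumes model/gcn_first.pth exists and reuses the MUTAG candidate set defined in the training script further below:

g = Generator(model_path='model/gcn_first.pth',
              C=['C.4', 'N.5', 'O.2', 'F.1', 'I.7', 'Cl.7', 'Br.5'],
              node_feature_dim=7, start=0)
print(g.G['adj'].shape)   # torch.Size([35, 35]): MAX_NUM_NODES + len(C)
print(g.G['feat'].shape)  # torch.Size([35, 7])
print(g.G['num_nodes'])   # 1: only the start node exists at t = 0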

The reward function is defined as follows:

### The reward function
    def calculate_reward(self, G_t_1):
        """
        Rtr     Calculated from graph rules to encourage generated graphs to be valid
                1. Only one edge to be added between any two nodes
                2. Generated graph cannot contain more nodes than predefined maximum node number
                3. (For chemical) Degree cannot exceed valency
                If generated graph violates graph rule, Rtr = -1

        Rtf     Feedback from trained model
        """

        rtr = self.check_graph_rules(G_t_1)

        rtf = self.calculate_reward_feedback(G_t_1)
        rtf_sum = 0
        for m in range(rollout):
            p_start, a_start, p_end, a_end, G_t_1 = self.forward(G_t_1)
            rtf_sum += self.calculate_reward_feedback(G_t_1)
        rtf = rtf + rtf_sum * self.hyp1 / rollout

        return rtf + self.hyp2 * rtr

    def calculate_reward_feedback(self, G_t_1):
        """
        p(f(G_t_1) = c) - 1/l
        where l denotes number of possible classes for f
        """
        f = self.model(G_t_1['feat'], G_t_1['adj'])  # GCN.forward takes (x, adj)
        return f[self.c] - 1 / len(f)


    ## graph rules
    def check_graph_rules(self, G_t_1):
        """
        For mutag, node degrees cannot exceed valency
        """
        # Walk over node degrees together with their node indices
        for idx, d in enumerate(G_t_1['degrees']):
            if d != 0:
                node_id = torch.argmax(G_t_1['feat'][idx])  # Eg. [0, 1, 0, 0] -> 1
                node = self.candidate_set[node_id]  # Eg ['C.4', 'F.2', 'Br.7'][1] = 'F.2'
                max_valency = int(node.split('.')[1])  # Eg. C.4 -> ['C', '4'] -> 4

                # If any node degree exceeds its valency, return -1
                if max_valency < d:
                    return -1

        return 0

Note that

  • the graph rules only check whether each node's degree exceeds its atom's valency, returning -1 if the graph is invalid and 0 otherwise; the 'C.4'-style valency encoding can be checked in isolation, as in the snippet below
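
The valency encoding used by check_graph_rules can be verified on its own:

node = 'C.4'                            # element symbol, then its maximum valency
max_valency = int(node.split('.')[1])   # 'C.4' -> ['C', '4'] -> 4
print(node.split('.')[0], max_valency)  # C 4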

The loss:

## Compute the loss
    def calculate_loss(self, Rt, p_start, a_start, p_end, a_end, G_t_1):
        """
        Calculated from cross entropy loss (Lce) and reward function (Rt)
        where loss = -Rt*(Lce_start + Lce_end)
        """

        Lce_start = F.cross_entropy(torch.reshape(p_start, (1, 35)), a_start.unsqueeze(0))
        Lce_end = F.cross_entropy(torch.reshape(p_end, (1, 35)), a_end.unsqueeze(0))

        return -Rt * (Lce_start + Lce_end)
  • 35 is MAX_NUM_NODES + len(candidate set) = 35

Here both the reward and the loss are member functions of the Generator class.

Training code

from GraphGenerator import Generator
import copy

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

import torch
import torch.optim as optim

lr = 0.01
b1 = 0.9
b2 = 0.99
hyp1 = 1
hyp2 = 2
max_gen_step = 10  # T = 10

candidate_set = ['C.4', 'N.5', 'O.2', 'F.1', 'I.7', 'Cl.7', 'Br.5']  # 'C.4' means a carbon node's degree may not exceed 4
model_path = 'model/gcn_first.pth'

## Train the generator
def train_generator(c=0, max_nodes=5):
    g.c = c
    for i in range(max_gen_step):
        optimizer.zero_grad()
        G = copy.deepcopy(g.G)
        p_start, a_start, p_end, a_end, G = g.forward(G)

        Rt = g.calculate_reward(G)
        loss = g.calculate_loss(Rt, p_start, a_start, p_end, a_end, G)
        loss.backward()
        optimizer.step()

        if G['num_nodes'] > max_nodes:
            g.reset_graph()
        elif Rt > 0:
            g.G = G


## Generate a graph
def generate_graph(c=0, max_nodes=5):
    g.c = c
    g.reset_graph()

    for i in range(max_gen_step):
        G = copy.deepcopy(g.G)
        p_start, a_start, p_end, a_end, G = g.forward(G)
        Rt = g.calculate_reward(G)

        if G['num_nodes'] > max_nodes:
            return g.G
        elif Rt > 0:
            g.G = G

    return g.G

## Plot the generated graph
def display_graph(G):
    G_nx = nx.from_numpy_matrix(np.asmatrix(G['adj'][:G['num_nodes'], :G['num_nodes']].numpy()))
    # nx.draw_networkx(G_nx)

    layout=nx.spring_layout(G_nx)
    nx.draw(G_nx, layout)

    coloring=torch.argmax(G['feat'],1)
    colors=['b','g','r','c','m','y','k']

    for i in range(7):
        nx.draw_networkx_nodes(G_nx,pos=layout,nodelist=[x for x in G_nx.nodes() if coloring[x]==i],node_color=colors[i])
        nx.draw_networkx_labels(G_nx,pos=layout,labels={x:candidate_set[i].split('.')[0] for x in G_nx.nodes() if coloring[x]==i})
    nx.draw_networkx_edges(G_nx,pos=layout,width=list(nx.get_edge_attributes(G_nx,'weight').values()))
    nx.draw_networkx_edge_labels(G_nx,pos=layout,edge_labels=nx.get_edge_attributes(G_nx, "weight"))

    plt.show()

if __name__ == '__main__':
    g = Generator(model_path = model_path, C = candidate_set, node_feature_dim=7 ,c=0, start=0)
    optimizer = optim.Adam(g.parameters(), lr=lr, betas=(b1, b2))

    for i in range(1, 10):
        ## Generate graph structures containing at most i nodes
        g.reset_graph()
        train_generator(c=1, max_nodes=i)
        to_display = generate_graph(c=1, max_nodes=i)
        display_graph(to_display)
        print(g.model(to_display['feat'], to_display['adj']))

The training process here cannot be evaluated numerically, only visualized. Below, for each node budget from 1 to 9, I generate the subgraph structure that the GCN classifier $f(\cdot)$ predicts as class 1, together with the predicted probability. The results:

(The generated-graph figures are omitted; listed below is the predicted class-1 probability for each node budget.)

  • 1 node: 0.7715
  • 2 nodes: 0.7935
  • 3 nodes: 0.8358
  • 4 nodes: 0.8556
  • 5 nodes: 0.8778
  • 6 nodes: 0.8533
  • 7 nodes: 0.9010
  • 8 nodes: 0.9005
  • 9 nodes: 0.8510
The gap from the paper's results is still quite noticeable; hyperparameter tuning takes real skill, and perhaps I just haven't mastered it yet.

References

[1] H. Yuan, J. Tang, X. Hu, and S. Ji, "XGNN: Towards model-level explanations of graph neural networks," in Proceedings of KDD '20. New York, NY, USA: Association for Computing Machinery, 2020, pp. 430-438. [Online]. Available: https://doi.org/10.1145/3394486.3403085
