Interpretability Research (Part 4): Inside GNNExplainer

I covered the GNNExplainer paper earlier, but the pile of formulas in it is hard to get a grip on, so I went to GitHub and looked at several GNNExplainer implementations.

GNNExplainer explains a graph from two angles:

  • Edge: it learns an edge mask giving the probability that each edge appears in the graph, as a float between 0 and 1. The edge mask can also be read as a weight, so you can take the top-k edges and use the subgraph they form as the explanation (see the sketch after this list).
  • Node feature: the node feature (NF) vector is the node's representation; e.g. a 128-dimensional node has 128 features, and GNNExplainer can simultaneously learn an NF mask assigning each feature a weight. This mask is optional.
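
For example, extracting the top-k edges from a learned mask might look like the minimal sketch below (my own illustration; the name topk_subgraph and the tensor shapes are assumptions, not part of the code that follows):

import torch

def topk_subgraph(edge_mask: torch.Tensor, edge_index: torch.Tensor, k: int = 10):
    """Keep the k highest-weight edges as the explanatory subgraph.

    edge_mask:  [E] float scores, e.g. sigmoid of the learned mask
    edge_index: [2, E] source/target node indices of each edge
    """
    k = min(k, edge_mask.numel())
    top = edge_mask.topk(k).indices       # indices of the k highest-weight edges
    return edge_index[:, top], edge_mask[top]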

The version below is based on DIG's implementation; here is a stripped-down explain.py.

import torch
from torch import Tensor
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from math import sqrt

from configuration import data_args

from torch_geometric.data import Batch, Data
from torch.nn.functional import cross_entropy

class ExplainerBase(nn.Module):
    def __init__(self, model: nn.Module, epochs=0, lr=0, explain_graph=False, molecule=False):
        super().__init__()
        self.model = model
        self.lr = lr
        self.epochs = epochs
        self.explain_graph = explain_graph
        self.molecule = molecule
        self.mp_layers = [module for module in self.model.modules() if isinstance(module, MessagePassing)]
        self.num_layers = len(self.mp_layers)

        self.ori_pred = None
        self.ex_labels = None
        self.edge_mask = None
        self.hard_edge_mask = None

        self.num_edges = None
        self.num_nodes = None
        self.device = None


    def __set_masks__(self, x, edge_index, init="normal"):
        (N, F), E = x.size(), edge_index.size(1)

        std = 0.1
        self.node_feat_mask = torch.nn.Parameter(torch.randn(F, requires_grad=True, device=self.device) * std)

        std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
        self.edge_mask = torch.nn.Parameter(torch.randn(E, requires_grad=True, device=self.device) * std)
        # self.edge_mask = torch.nn.Parameter(100 * torch.ones(E, requires_grad=True))

        # Register the mask on every MessagePassing layer; propagate() will then
        # multiply each edge's message by sigmoid(edge_mask) while __explain__ is True.
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = True
                module.__edge_mask__ = self.edge_mask

    def __clear_masks__(self):
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = False
                module.__edge_mask__ = None
        self.node_feat_mask = None
        self.edge_mask = None

    @property
    def __num_hops__(self):
        if self.explain_graph:
            return -1
        else:
            return self.num_layers

    def __flow__(self):
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                return module.flow
        return 'source_to_target'



    def forward(self,
                x: Tensor,
                edge_index: Tensor,
                **kwargs
                ):
        self.num_edges = edge_index.shape[1]
        self.num_nodes = x.shape[0]
        self.device = x.device


    def eval_related_pred(self, x, edge_index, edge_masks, **kwargs):

        node_idx = kwargs.get('node_idx')
        node_idx = 0 if node_idx is None else node_idx  # graph level: 0, node level: node_idx
        related_preds = []

        for ex_label, edge_mask in enumerate(edge_masks):

            # +inf raw mask -> sigmoid gives 1 for every edge: original prediction
            self.edge_mask.data = float('inf') * torch.ones(edge_mask.size(), device=data_args.device)
            ori_pred = self.model(x=x, edge_index=edge_index, **kwargs)

            self.edge_mask.data = edge_mask
            masked_pred = self.model(x=x, edge_index=edge_index, **kwargs)

            # mask out important elements for fidelity calculation
            self.edge_mask.data = - edge_mask  # keep Parameter's id
            maskout_pred = self.model(x=x, edge_index=edge_index, **kwargs)

            # -inf raw mask -> sigmoid gives 0 for every edge: empty graph
            self.edge_mask.data = - float('inf') * torch.ones(edge_mask.size(), device=data_args.device)
            zero_mask_pred = self.model(x=x, edge_index=edge_index, **kwargs)

            related_preds.append({'zero': zero_mask_pred[node_idx],
                                  'masked': masked_pred[node_idx],
                                  'maskout': maskout_pred[node_idx],
                                  'origin': ori_pred[node_idx]})


        return related_preds



EPS = 1e-15


class GNNExplainer(ExplainerBase):
    r"""The GNN-Explainer model from the `"GNNExplainer: Generating
    Explanations for Graph Neural Networks"
    <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph
    structures and small subsets of node features that play a crucial role in a
    GNN’s node-predictions.

    .. note::

        For an example of using GNN-Explainer, see `examples/gnn_explainer.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        gnn_explainer.py>`_.

    Args:
        model (torch.nn.Module): The GNN module to explain.
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`50`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.001`)
        explain_graph (bool, optional): If set to :obj:`True`, explain
            graph-level predictions instead of node-level ones.
            (default: :obj:`True`)
    """

    coeffs = {
        'edge_size': 0.005,
        'node_feat_size': 1.0,
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
    }

    def __init__(self, model, epochs=50, lr=0.001, explain_graph=True, molecule=False):
        super(GNNExplainer, self).__init__(model, epochs, lr, explain_graph, molecule)

    def __loss__(self, raw_preds, x_label):
        loss = cross_entropy(raw_preds, x_label)
        m = self.edge_mask.sigmoid()
        loss = loss + self.coeffs['edge_size'] * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['edge_ent'] * ent.mean()

        if self.mask_features:
            m = self.node_feat_mask.sigmoid()
            loss = loss + self.coeffs['node_feat_size'] * m.sum()
            ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
            loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

        return loss

    def gnn_explainer_alg(self, x: Tensor, edge_index: Tensor, ex_label: Tensor, mask_features: bool = False, **kwargs) -> Tensor:
        # initialize a mask
        patience = 10
        self.to(x.device)
        self.mask_features = mask_features

        # train to get the mask
        optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],
                                     lr=self.lr)

        best_loss = 4.0
        count = 0
        for epoch in range(1, self.epochs + 1):
            if mask_features:
                h = x * self.node_feat_mask.view(1, -1).sigmoid()
            else:
                h = x
            raw_preds = self.model(data=Batch.from_data_list([Data(x=h, edge_index=edge_index)]))
            loss = self.__loss__(raw_preds, ex_label)
            # if epoch % 10 == 0:
            #     print(f'#D#Loss:{loss.item()}')

            is_best = (loss < best_loss)

            if not is_best:
                count += 1
            else:
                count = 0
                best_loss = loss

            if count >= patience:
                break

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return self.edge_mask.data

    def forward(self, x, edge_index, mask_features=False,
                positive=True, **kwargs):
        r"""Learns and returns a node feature mask and an edge mask that play a
        crucial role in explaining the prediction made by the GNN.

        Args:
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            mask_features (bool, optional): If set to :obj:`True`, also learn
                a node feature mask. (default: :obj:`False`)
            **kwargs (optional): Additional arguments passed to the GNN module.

        :rtype: (:class:`Tensor`, :class:`Tensor`)
        """
        self.model.eval()
        # self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)
        # Only operate on a k-hop subgraph around `node_idx`.
        # Calculate mask

        ex_label = torch.tensor([1]).to(data_args.device)
        self.__clear_masks__()
        self.__set_masks__(x, edge_index)
        edge_mask = self.gnn_explainer_alg(x, edge_index, ex_label)
        # edge_masks.append(self.gnn_explainer_alg(x, edge_index, ex_label))


        # with torch.no_grad():
        #     related_preds = self.eval_related_pred(x, edge_index, edge_masks, **kwargs)
        self.__clear_masks__()
        sorted_results = edge_mask.sort(descending=True)
        return edge_mask.detach(), sorted_results.indices.cpu(), edge_index.cpu()


    def __repr__(self):
        return f'{self.__class__.__name__}()'

GNNExplainer.forward

The entry point is GNNExplainer.forward. Here I only explain samples classified as class 1.

def forward(self, x, edge_index, mask_features=False,
                positive=True, **kwargs):
        r"""Learns and returns a node feature mask and an edge mask that play a
        crucial role in explaining the prediction made by the GNN.

        Args:
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            mask_features (bool, optional): If set to :obj:`True`, also learn
                a node feature mask. (default: :obj:`False`)
            **kwargs (optional): Additional arguments passed to the GNN module.

        :rtype: (:class:`Tensor`, :class:`Tensor`)
        """
        self.model.eval()
        # self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)
        # Only operate on a k-hop subgraph around `node_idx`.
        # Calculate mask

        ex_label = torch.tensor([1]).to(data_args.device)
        self.__clear_masks__()
        self.__set_masks__(x, edge_index)
        edge_mask = self.gnn_explainer_alg(x, edge_index, ex_label)
        # edge_masks.append(self.gnn_explainer_alg(x, edge_index, ex_label))


        # with torch.no_grad():
        #     related_preds = self.eval_related_pred(x, edge_index, edge_masks, **kwargs)
        self.__clear_masks__()
        sorted_results = edge_mask.sort(descending=True)
        return edge_mask.detach(), sorted_results.indices.cpu(), edge_index.cpu()
  • The function first calls __clear_masks__() to remove any previous edge mask (there is none on the first call), then calls __set_masks__ to install a randomly initialized edge mask.

  • It returns the edge_mask along with the edge indices sorted by mask value. The edge mask itself is computed by gnn_explainer_alg (a usage sketch follows this list).
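
Putting it together, a minimal usage sketch (my own illustration: model stands for some trained graph classifier that, as in gnn_explainer_alg above, accepts a Batch via the data= keyword; the toy tensors are made up):

import torch

# Toy graph: 4 nodes with 16-dim features, 5 directed edges.
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 1, 2, 3],
                           [1, 0, 2, 3, 1]])

explainer = GNNExplainer(model, epochs=50, lr=0.001, explain_graph=True)
edge_mask, sorted_edge_ids, edge_index_out = explainer(x, edge_index)

# The highest-weight edges form the explanatory subgraph.
top3 = edge_index_out[:, sorted_edge_ids[:3]]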

ExplainerBase.__set_masks__

You only need a rough idea of this code: it installs the randomly initialized edge mask and NF mask. I trimmed some of the NF mask handling to simplify the code; see the DIG code for the full version. (A short note on the initialization scale follows the code.)

def __set_masks__(self, x, edge_index, init="normal"):
        (N, F), E = x.size(), edge_index.size(1)

        std = 0.1
        self.node_feat_mask = torch.nn.Parameter(torch.randn(F, requires_grad=True, device=self.device) * std)

        std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
        self.edge_mask = torch.nn.Parameter(torch.randn(E, requires_grad=True, device=self.device) * std)
        # self.edge_mask = torch.nn.Parameter(100 * torch.ones(E, requires_grad=True))

        # Register the mask on every MessagePassing layer; propagate() will then
        # multiply each edge's message by sigmoid(edge_mask) while __explain__ is True.
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = True
                module.__edge_mask__ = self.edge_mask
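
As a side note, the edge-mask init std above is a Xavier/Glorot-style rule, gain * sqrt(2 / (fan_in + fan_out)) with fan_in = fan_out = N. A quick sanity check of the value (my own illustration):

from math import sqrt
import torch

N = 25  # number of nodes, as an example
std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
print(std)  # calculate_gain('relu') == sqrt(2), so std == sqrt(2) * 0.2 ≈ 0.283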

GNNExplainer.gnn_explainer_alg

This is where the optimal edge mask is computed; the NF mask is skipped for now.

def gnn_explainer_alg(self, x: Tensor, edge_index: Tensor, ex_label: Tensor, mask_features: bool = False, **kwargs) -> Tensor:
        # initialize a mask
        patience = 10
        self.to(x.device)
        self.mask_features = mask_features

        # train to get the mask
        optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],
                                     lr=self.lr)

        best_loss = 4.0
        count = 0
        for epoch in range(1, self.epochs + 1):
            if mask_features:
                h = x * self.node_feat_mask.view(1, -1).sigmoid()
            else:
                h = x
            raw_preds = self.model(data=Batch.from_data_list([Data(x=h, edge_index=edge_index)]))
            loss = self.__loss__(raw_preds, ex_label)
            # if epoch % 10 == 0:
            #     print(f'#D#Loss:{loss.item()}')

            is_best = (loss < best_loss)

            if not is_best:
                count += 1
            else:
                count = 0
                best_loss = loss

            if count >= patience:
                break

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return self.edge_mask.data

Here the edge mask (and optionally the NF mask) is treated as a trainable parameter and optimized the way network weights are. The loss function is as follows:

    def __loss__(self, raw_preds, x_label):
        loss = cross_entropy(raw_preds, x_label)
        m = self.edge_mask.sigmoid()
        loss = loss + self.coeffs['edge_size'] * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['edge_ent'] * ent.mean()

        if self.mask_features:
            m = self.node_feat_mask.sigmoid()
            loss = loss + self.coeffs['node_feat_size'] * m.sum()
            ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
            loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

        return loss

The raw_preds in this loss is already the output computed with the mask in effect, since __clear_masks__ is never called again to remove the mask during training.

The loss has three parts:

  • The cross-entropy between the target label and the model output with the edge mask applied (I'm not sure whether comparing against the output from before the mask was applied would work better). This part depends heavily on the model's own accuracy, and many GNN classification codebases have questionable accuracy to begin with.
  • The size of the edge_mask itself (the edge_size term: the sum over the mask).
  • How discrete the edge_mask is (the entropy term: the closer the values are to 0 or 1, the better).

Ignoring the NF mask, the whole loss can roughly be written as:

loss = CrossEntropy(f(data, model, edge_mask), label) + Size(edge_mask) + Discrete(edge_mask)
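
To make the two regularizers concrete, here is a toy computation of the size and entropy terms on a hand-picked mask (my own illustration, reusing the coeffs from the class above):

import torch

EPS = 1e-15
coeffs = {'edge_size': 0.005, 'edge_ent': 1.0}

# Raw (pre-sigmoid) mask for 4 edges: two clearly kept, two clearly dropped.
edge_mask = torch.tensor([4.0, 3.0, -4.0, -3.0])
m = edge_mask.sigmoid()  # ≈ [0.982, 0.953, 0.018, 0.047]: nearly discrete

size_term = coeffs['edge_size'] * m.sum()                        # Size: penalizes large masks
ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
ent_term = coeffs['edge_ent'] * ent.mean()                       # Discrete: penalizes values near 0.5

print(size_term.item(), ent_term.item())  # both small for this near-binary mask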

Training thus yields the edge mask with the lowest loss. One question the code above leaves open is how the edge mask actually takes effect inside f(data, model, edge_mask), i.e. how installing the edge mask changes the model's forward computation. Answering that requires looking at how torch_geometric supports GNNExplainer.

torch_geometric.nn GNNExplainer

The core class here is MessagePassing. Without GNNExplainer support, its propagate function looks like this:

coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)

msg_kwargs = self.inspector.distribute('message', coll_dict)
out = self.message(**msg_kwargs)

aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
out = self.aggregate(out, **aggr_kwargs)

update_kwargs = self.inspector.distribute('update', coll_dict)
return self.update(out, **update_kwargs)

In the basic example, message simply returns x_j, update likewise returns its input tensor unchanged, and aggregate gathers information from neighboring nodes.

  • out = self.message(**msg_kwargs) returns a tensor of shape [edge_num, node_feature_dim]: the vector of each edge's source node x_j, i.e. the message it sends along the edge.
  • out = self.aggregate(out, **aggr_kwargs) returns a tensor of shape [node_num, node_feature_dim]: each node's vector after aggregation (a toy sketch of these shapes follows this list).
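
A minimal sketch of those two shapes, using scatter-add as the aggregation (my own illustration, not the PyG internals):

import torch

# Toy graph: 3 nodes, 4 directed edges, 2-dim features.
x = torch.arange(6.0).view(3, 2)              # [N, F] = [3, 2]
edge_index = torch.tensor([[0, 1, 2, 2],      # source nodes j
                           [1, 2, 0, 1]])     # target nodes i

# "message": each edge carries its source node's features -> [E, F] = [4, 2]
out = x[edge_index[0]]

# "aggregate": sum the incoming messages at each target node -> [N, F] = [3, 2]
agg = torch.zeros(3, 2).index_add_(0, edge_index[1], out)
print(out.shape, agg.shape)  # torch.Size([4, 2]) torch.Size([3, 2])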

With GNNExplainer support added, propagate becomes:

coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)

msg_kwargs = self.inspector.distribute('message', coll_dict)
out = self.message(**msg_kwargs)

# For `GNNExplainer`, we require a separate message and aggregate
# procedure since this allows us to inject the `edge_mask` into the
# message passing computation scheme.
if self.__explain__:
    edge_mask = self.__edge_mask__.sigmoid()
    # Some ops add self-loops to `edge_index`. We need to do the
    # same for `edge_mask` (but do not train those).
    if out.size(self.node_dim) != edge_mask.size(0):
        loop = edge_mask.new_ones(size[0])
        edge_mask = torch.cat([edge_mask, loop], dim=0)
    assert out.size(self.node_dim) == edge_mask.size(0)
    out = out * edge_mask.view([-1] + [1] * (out.dim() - 1))

aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
out = self.aggregate(out, **aggr_kwargs)

update_kwargs = self.inspector.distribute('update', coll_dict)
return self.update(out, **update_kwargs)

As you can see, a block is inserted between message and aggregate; ignoring the guard and the reshaping, it boils down to out = out * edge_mask. In other words, when node x_i aggregates from its neighbors (counting in-edges only: if a directed edge (u, v) exists, then u is a neighbor of v), each neighbor's message is first multiplied by the corresponding edge mask value. (My reading here may be off; corrections are welcome.) A toy version of this masking step follows.
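
Continuing the toy example above, the inserted block effectively does the following (my own illustration):

import torch

x = torch.arange(6.0).view(3, 2)
edge_index = torch.tensor([[0, 1, 2, 2],
                           [1, 2, 0, 1]])

out = x[edge_index[0]]                         # [E, F] messages, one per edge

# Raw learnable mask, one scalar per edge; sigmoid maps it into (0, 1).
raw_edge_mask = torch.tensor([5.0, -5.0, 5.0, -5.0])
edge_mask = raw_edge_mask.sigmoid()            # ≈ [0.993, 0.007, 0.993, 0.007]

# The inserted step: scale each edge's message by its mask value.
out = out * edge_mask.view(-1, 1)              # broadcast over the feature dim

# Aggregation now sees nearly-zeroed messages on the "dropped" edges.
agg = torch.zeros(3, 2).index_add_(0, edge_index[1], out)
print(agg)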
