PyG MessagePassing代码注释

PyG MessagePassing代码注释import osimport reimport inspectimport os.path as ospfrom uuid import uuid1from itertools import chainfrom inspect import Parameterfrom typing import List, Optional, Setfrom torch_geometric.typing import Adj, Sizeimport tor
摘要由CSDN通过智能技术生成

PyG MessagePassing代码注释

import os
import re
import inspect
import os.path as osp
from uuid import uuid1
from itertools import chain
from inspect import Parameter
from typing import List, Optional, Set
from torch_geometric.typing import Adj, Size

import torch
from torch import Tensor
from jinja2 import Template
from torch_sparse import SparseTensor
from torch_scatter import gather_csr, scatter, segment_csr

from .utils.helpers import expand_left
from .utils.jit import class_from_module_repr
from .utils.typing import (sanitize, split_types_repr, parse_types,
                           resolve_types)
from .utils.inspector import Inspector, func_header_repr, func_body_repr


class MessagePassing(torch.nn.Module):
    r"""Base class for creating message passing layers of the form

    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
        \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
        \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),

    where :math:`\square` denotes a differentiable, permutation invariant
    function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
    and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
    MLPs.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
    create_gnn.html>`__ for the accompanying tutorial.

    Args:
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"` or :obj:`None`).
            (default: :obj:`"add"`)
        flow (string, optional): The flow direction of message passing
            (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
            (default: :obj:`"source_to_target"`)
        node_dim (int, optional): The axis along which to propagate.
            (default: :obj:`-2`)
    """

    special_args: Set[str] = {
   
        'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
        'size_i', 'size_j', 'ptr', 'index', 'dim_size'
    }

    def __init__(self, aggr: Optional[str] = "add",
                 flow: str = "source_to_target", node_dim: int = -2):

        super(MessagePassing, self).__init__()

        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max', None]

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim

        #+ inspector'调查员',读取这些消息传递的函数中,子类需要用到的函数的全部参数声明, 
        #+  如果某个函数子类没有重写(只是调用了父类中的),就读取父类中该函数的参数声明
        self.inspector = Inspector(self) 
        #. 调查&记录 子类中message函数的参数声明,
        #.   GATconv调用父类MessagePassing中的propagate函数, 父类propagaet调用GATconv中重写的message函数
        #.      因此inspector将 GATconv.message 中声明的参数列表保存下来
        self.inspector.inspect(self.message) 
        #. 父类MessagePassing中的propagate函数, 父类propagaet调用自己的aggregate函数(GATconv中没有重写该函数)
        #.      因此inspector将 MessagePassing.propagate 中声明的参数列表保存下来
        self.inspector.inspect(self.aggregate, pop_first=True) 
        #- 子类未调用message_and_aggregate, 不读取该函数的参数
        self.inspector.inspect(self.message_and_aggregate, pop_first=True)
        #- 子类未调用update, 不读取该函数的参数
        self.inspector.inspect(self.update, pop_first=True) 

        #+ 对于传递和聚合分开进行的方式,
        #+  查找inspector中保存这三个函数的参数声明与special_args中定义的参数的不同者,
        #+  不同者其实就是用户自己额外需要使用到的参数,在子类重写函数中另外定义的参数
        self.__user_args__ = self.inspector.keys( 
            ['message', 'aggregate', 'update']).difference(self.special_args)
        
        #+ 对于传递和聚合合起来进行的方式也是如此
        self.__fused_user_args__ = self.inspector.keys(
            ['message_and_aggregate', 'update']).difference(self.special_args)

        # Support for "fused" message passing.
        self.fuse = self.inspector.implements('message_and_aggregate')

        # Support for GNNExplainer.
        self.__explain__ = False
        self.__edge_mask__ = None

    def __check_input__(self, edge_index, size):
        the_size: List[Optional[int]] = [None, None]

        if isinstance(edge_index, Tensor):
            assert edge_index.dtype == torch.long
            assert edge_index.dim() == 2</
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值