graph_nets (brief annotations on DeepMind's graph_nets)

(1)blocks.py

# Copyright 2018 The GraphNets Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Building blocks for Graph Networks.

This module contains elementary building blocks of graph networks:

  - `broadcast_{field_1}_to_{field_2}` propagates the features from `field_1`
    onto the relevant elements of `field_2`;

  - `{field_1}To{field_2}Aggregator` propagates and then reduces the features
    from `field_1` onto the relevant elements of `field_2`;

  - the `EdgeBlock`, `NodeBlock` and `GlobalBlock` are elementary graph networks
    that only update the edges (resp. the nodes, the globals) of their input
    graph (as described in https://arxiv.org/abs/1806.01261).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from graph_nets import _base
from graph_nets import graphs
from graph_nets import utils_tf

import tensorflow as tf


NODES = graphs.NODES
EDGES = graphs.EDGES
GLOBALS = graphs.GLOBALS
RECEIVERS = graphs.RECEIVERS
SENDERS = graphs.SENDERS
N_NODE = graphs.N_NODE
N_EDGE = graphs.N_EDGE


def _validate_graph(graph, mandatory_fields, additional_message=None):
  for field in mandatory_fields:
    if getattr(graph, field) is None:
      message = "`{}` field cannot be None".format(field)
      if additional_message:
        message += " " + format(additional_message)
      message += "."
      raise ValueError(message)


def _validate_broadcasted_graph(graph, from_field, to_field):
  additional_message = "when broadcasting {} to {}".format(from_field, to_field)
  _validate_graph(graph, [from_field, to_field], additional_message)


def broadcast_globals_to_edges(graph, name="broadcast_globals_to_edges"): # repeat each graph's globals once per edge of that graph, concatenated along axis 0
  """Broadcasts the global features to the edges of a graph.

  Args:
    graph: A `graphs.GraphsTuple` containing `Tensor`s, with globals features of
      shape `[n_graphs] + global_shape`, and `N_EDGE` field of shape
      `[n_graphs]`.
    name: (string, optional) A name for the operation.

  Returns:
    A tensor of shape `[n_edges] + global_shape`, where
    `n_edges = sum(graph.n_edge)`. The i-th element of this tensor is given by
    `globals[j]`, where j is the index of the graph the i-th edge belongs to
    (i.e. is such that
    `sum_{k < j} graphs.n_edge[k] <= i < sum_{k <= j} graphs.n_edge[k]`).

  Raises:
    ValueError: If either `graph.globals` or `graph.n_edge` is `None`.
  """
  _validate_broadcasted_graph(graph, GLOBALS, N_EDGE) # raises ValueError if `graph.globals` or `graph.n_edge` is None
  with tf.name_scope(name): # prefixes the names of the ops created below
    return utils_tf.repeat(graph.globals, graph.n_edge, axis=0) # repeat each graph's globals as many times as that graph has edges, concatenated along axis 0


def broadcast_globals_to_nodes(graph, name="broadcast_globals_to_nodes"): # repeat each graph's globals once per node of that graph, concatenated along axis 0
  """Broadcasts the global features to the nodes of a graph.

  Args:
    graph: A `graphs.GraphsTuple` containing `Tensor`s, with globals features of
      shape `[n_graphs] + global_shape`, and `N_NODE` field of shape
      `[n_graphs]`.
    name: (string, optional) A name for the operation.

  Returns:
    A tensor of shape `[n_nodes] + global_shape`, where
    `n_nodes = sum(graph.n_node)`. The i-th element of this tensor is given by
    `globals[j]`, where j is the index of the graph the i-th node belongs to
    (i.e. is such that
    `sum_{k < j} graphs.n_node[k] <= i < sum_{k <= j} graphs.n_node[k]`).

  Raises:
    ValueError: If either `graph.globals` or `graph.n_node` is `None`.
  """
  _validate_broadcasted_graph(graph, GLOBALS, N_NODE) # raises ValueError if `graph.globals` or `graph.n_node` is None
  with tf.name_scope(name): # prefixes the names of the ops created below
    return utils_tf.repeat(graph.globals, graph.n_node, axis=0) # repeat each graph's globals as many times as that graph has nodes, concatenated along axis 0


def broadcast_sender_nodes_to_edges( # for every edge, gather the features of its sender node
    graph, name="broadcast_sender_nodes_to_edges"):
  """Broadcasts the node features to the edges they are sending into.

  Args:
    graph: A `graphs.GraphsTuple` containing `Tensor`s, with nodes features of
      shape `[n_nodes] + node_shape`, and `senders` field of shape
      `[n_edges]`.
    name: (string, optional) A name for the operation.

  Returns:
    A tensor of shape `[n_edges] + node_shape`, where
    `n_edges = sum(graph.n_edge)`. The i-th element is given by
    `graph.nodes[graph.senders[i]]`.

  Raises:
    ValueError: If either `graph.nodes` or `graph.senders` is `None`.
  """
  _validate_broadcasted_graph(graph, NODES, SENDERS) # raises ValueError if `graph.nodes` or `graph.senders` is None
  with tf.name_scope(name): # prefixes the names of the ops created below
    return tf.gather(graph.nodes, graph.senders) # for each edge, the feature vector of its sender node


def broadcast_receiver_nodes_to_edges( # for every edge, gather the features of its receiver node
    graph, name="broadcast_receiver_nodes_to_edges"):
  """Broadcasts the node features to the edges they are receiving from.

  Args:
    graph: A `graphs.GraphsTuple` containing `Tensor`s, with nodes features of
      shape `[n_nodes] + node_shape`, and receivers of shape `[n_edges]`.
    name: (string, optional) A name for the operation.

  Returns:
    A tensor of shape `[n_edges] + node_shape`, where
    `n_edges = sum(graph.n_edge)`. The i-th element is given by
    `graph.nodes[graph.receivers[i]]`.

  Raises:
    ValueError: If either `graph.nodes` or `graph.receivers` is `None`.
  """
  _validate_broadcasted_graph(graph, NODES, RECEIVERS) # raises ValueError if `graph.nodes` or `graph.receivers` is None
  with tf.name_scope(name): # prefixes the names of the ops created below
    return tf.gather(graph.nodes, graph.receivers) # for each edge, the feature vector of its receiver node
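
A quick example of what the broadcast helpers return (annotator's sketch, assuming graph_nets and TensorFlow are installed; the toy batch below is purely illustrative):

  import tensorflow as tf
  from graph_nets import blocks, utils_tf

  toy_graphs = utils_tf.data_dicts_to_graphs_tuple([
      {"nodes": [[0.], [1.]], "edges": [[10.]], "senders": [0], "receivers": [1],
       "globals": [100.]},
      {"nodes": [[2.]], "edges": [[20.]], "senders": [0], "receivers": [0],
       "globals": [200.]},
  ])
  blocks.broadcast_globals_to_edges(toy_graphs)         # [[100.], [200.]]
  blocks.broadcast_sender_nodes_to_edges(toy_graphs)    # [[0.], [2.]]
  blocks.broadcast_receiver_nodes_to_edges(toy_graphs)  # [[1.], [2.]]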


class EdgesToGlobalsAggregator(_base.AbstractModule): # reduces (e.g. sums) all edge features within each graph; the leading output dimension is the number of graphs
  """Aggregates all edges into globals."""

  def __init__(self, reducer, name="edges_to_globals_aggregator"):
    """Initializes the EdgesToGlobalsAggregator module.

    The reducer is used for combining per-edge features (one set of edge
    feature vectors per graph) to give per-graph features (one feature
    vector per graph). The reducer should take a `Tensor` of edge features, a
    `Tensor` of segment indices, and a number of graphs. It should be invariant
    under permutation of edge features within each graph.

    Examples of compatible reducers are:
    * tf.math.unsorted_segment_sum
    * tf.math.unsorted_segment_mean
    * tf.math.unsorted_segment_prod
    * unsorted_segment_min_or_zero
    * unsorted_segment_max_or_zero

    Args:
      reducer: A function for reducing sets of per-edge features to individual
        per-graph features.
      name: The module name.
    """
    super(EdgesToGlobalsAggregator, self).__init__(name=name)
    self._reducer = reducer

  def _build(self, graph):
    _validate_graph(graph, (EDGES,), # check that the required fields are present
                    additional_message="when aggregating from edges.")
    num_graphs = utils_tf.get_num_graphs(graph) # number of graphs in the batch
    graph_index = tf.range(num_graphs) # tensor [0, 1, ..., num_graphs - 1]
    indices = utils_tf.repeat(graph_index, graph.n_edge, axis=0) # repeat each graph index once per edge of that graph, giving one graph id per edge
    return self._reducer(graph.edges, indices, num_graphs) # reduce edge features per graph; the leading output dimension is the number of graphs


class NodesToGlobalsAggregator(_base.AbstractModule): # reduces (e.g. sums) all node features within each graph; the leading output dimension is the number of graphs
  """Aggregates all nodes into globals."""

  def __init__(self, reducer, name="nodes_to_globals_aggregator"):
    """Initializes the NodesToGlobalsAggregator module.

    The reducer is used for combining per-node features (one set of node
    feature vectors per graph) to give per-graph features (one feature
    vector per graph). The reducer should take a `Tensor` of node features, a
    `Tensor` of segment indices, and a number of graphs. It should be invariant
    under permutation of node features within each graph.

    Examples of compatible reducers are:
    * tf.math.unsorted_segment_sum
    * tf.math.unsorted_segment_mean
    * tf.math.unsorted_segment_prod
    * unsorted_segment_min_or_zero
    * unsorted_segment_max_or_zero

    Args:
      reducer: A function for reducing sets of per-node features to individual
        per-graph features.
      name: The module name.
    """
    super(NodesToGlobalsAggregator, self).__init__(name=name)
    self._reducer = reducer # aggregation function

  def _build(self, graph):
    _validate_graph(graph, (NODES,),
                    additional_message="when aggregating from nodes.")
    num_graphs = utils_tf.get_num_graphs(graph) # number of graphs in the batch
    graph_index = tf.range(num_graphs) # tensor [0, 1, ..., num_graphs - 1]
    indices = utils_tf.repeat(graph_index, graph.n_node, axis=0) # repeat each graph index once per node of that graph, giving one graph id per node
    return self._reducer(graph.nodes, indices, num_graphs) # reduce node features per graph; the leading output dimension is the number of graphs
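
The reducers are plain segment ops; a toy check of the aggregation (annotator's sketch, the tensors are illustrative):

  import tensorflow as tf
  node_features = tf.constant([[1.], [2.], [3.]])  # 2 nodes in graph 0, 1 node in graph 1
  segment_ids = tf.constant([0, 0, 1])             # i.e. utils_tf.repeat(tf.range(2), [2, 1], axis=0)
  tf.math.unsorted_segment_sum(node_features, segment_ids, 2)  # [[3.], [3.]]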


class _EdgesToNodesAggregator(_base.AbstractModule): # aggregates edge features into the corresponding nodes
  """Aggregates sent or received edges into the corresponding nodes."""

  def __init__(self, reducer, use_sent_edges=False,
               name="edges_to_nodes_aggregator"):
    super(_EdgesToNodesAggregator, self).__init__(name=name)
    self._reducer = reducer # aggregation function; a custom reducer must follow the `tf.math.unsorted_segment_sum` signature
    self._use_sent_edges = use_sent_edges # whether to aggregate the edges sent by each node instead of the edges it receives

  def _build(self, graph):
    _validate_graph(graph, (EDGES, SENDERS, RECEIVERS,), # check that the required fields are present
                    additional_message="when aggregating from edges.")
    # If the number of nodes are known at graph construction time (based on the
    # shape) then use that value to make the model compatible with XLA/TPU.
    if graph.nodes is not None and graph.nodes.shape.as_list()[0] is not None:
      num_nodes = graph.nodes.shape.as_list()[0] # total number of nodes in the batch (known statically)
    else:
      num_nodes = tf.reduce_sum(graph.n_node) # total number of nodes in the batch (computed dynamically)
    indices = graph.senders if self._use_sent_edges else graph.receivers # use sender or receiver node indices as segment ids
    return self._reducer(graph.edges, indices, num_nodes) # reduce the edges incident to each node; the leading output dimension is the total number of nodes


class SentEdgesToNodesAggregator(_EdgesToNodesAggregator): # segments indexed by sender nodes
  """Aggregates sent edges into the corresponding sender nodes."""

  def __init__(self, reducer, name="sent_edges_to_nodes_aggregator"):
    """Constructor.

    The reducer is used for combining per-edge features (one set of edge
    feature vectors per node) to give per-node features (one feature
    vector per node). The reducer should take a `Tensor` of edge features, a
    `Tensor` of segment indices, and a number of nodes. It should be invariant
    under permutation of edge features within each segment.

    Examples of compatible reducers are:
    * tf.math.unsorted_segment_sum
    * tf.math.unsorted_segment_mean
    * tf.math.unsorted_segment_prod
    * unsorted_segment_min_or_zero
    * unsorted_segment_max_or_zero

    Args:
      reducer: A function for reducing sets of per-edge features to individual
        per-node features.
      name: The module name.
    """
    super(SentEdgesToNodesAggregator, self).__init__(
        use_sent_edges=True,
        reducer=reducer,
        name=name)


class ReceivedEdgesToNodesAggregator(_EdgesToNodesAggregator): # segments indexed by receiver nodes
  """Aggregates received edges into the corresponding receiver nodes."""

  def __init__(self, reducer, name="received_edges_to_nodes_aggregator"):
    """Constructor.

    The reducer is used for combining per-edge features (one set of edge
    feature vectors per node) to give per-node features (one feature
    vector per node). The reducer should take a `Tensor` of edge features, a
    `Tensor` of segment indices, and a number of nodes. It should be invariant
    under permutation of edge features within each segment.

    Examples of compatible reducers are:
    * tf.math.unsorted_segment_sum
    * tf.math.unsorted_segment_mean
    * tf.math.unsorted_segment_prod
    * unsorted_segment_min_or_zero
    * unsorted_segment_max_or_zero

    Args:
      reducer: A function for reducing sets of per-edge features to individual
        per-node features.
      name: The module name.
    """
    super(ReceivedEdgesToNodesAggregator, self).__init__(
        use_sent_edges=False, reducer=reducer, name=name)


def _unsorted_segment_reduction_or_zero(reducer, values, indices, num_groups):
  """Common code for unsorted_segment_{min,max}_or_zero (below)."""
  reduced = reducer(values, indices, num_groups)
  present_indices = tf.math.unsorted_segment_max(
      tf.ones_like(indices, dtype=reduced.dtype), indices, num_groups)
  present_indices = tf.clip_by_value(present_indices, 0, 1)
  present_indices = tf.reshape(
      present_indices, [num_groups] + [1] * (reduced.shape.ndims - 1))
  reduced *= present_indices
  return reduced


def unsorted_segment_min_or_zero(values, indices, num_groups,
                                 name="unsorted_segment_min_or_zero"):
  """Aggregates information using elementwise min.

  Segments with no elements are given a "min" of zero instead of the most
  positive finite value possible (which is what `tf.math.unsorted_segment_min`
  would do).

  Args:
    values: A `Tensor` of per-element features.
    indices: A 1-D `Tensor` whose length is equal to `values`' first dimension.
    num_groups: A `Tensor`.
    name: (string, optional) A name for the operation.

  Returns:
    A `Tensor` of the same type as `values`.
  """
  with tf.name_scope(name):
    return _unsorted_segment_reduction_or_zero(
        tf.math.unsorted_segment_min, values, indices, num_groups)


def unsorted_segment_max_or_zero(values, indices, num_groups,
                                 name="unsorted_segment_max_or_zero"):
  """Aggregates information using elementwise max.

  Segments with no elements are given a "max" of zero instead of the most
  negative finite value possible (which is what `tf.math.unsorted_segment_max`
  would do).

  Args:
    values: A `Tensor` of per-element features.
    indices: A 1-D `Tensor` whose length is equal to `values`' first dimension.
    num_groups: A `Tensor`.
    name: (string, optional) A name for the operation.

  Returns:
    A `Tensor` of the same type as `values`.
  """
  with tf.name_scope(name):
    return _unsorted_segment_reduction_or_zero(
        tf.math.unsorted_segment_max, values, indices, num_groups)
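
A toy check of the "or zero" behaviour (annotator's sketch; the values are illustrative):

  import tensorflow as tf
  values = tf.constant([[1.], [5.]])
  indices = tf.constant([0, 0])  # segment 1 receives no elements
  tf.math.unsorted_segment_max(values, indices, 2)  # [[5.], [~-3.4e38]] (lowest float32)
  unsorted_segment_max_or_zero(values, indices, 2)  # [[5.], [0.]]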


class EdgeBlock(_base.AbstractModule): # edge update block
  """Edge block.

  A block that updates the features of each edge in a batch of graphs based on
  (a subset of) the previous edge features, the features of the adjacent nodes,
  and the global features of the corresponding graph.

  See https://arxiv.org/abs/1806.01261 for more details.
  """

  def __init__(self,
               edge_model_fn, # a callable returning a Sonnet module
               use_edges=True, # configuration flags
               use_receiver_nodes=True,
               use_sender_nodes=True,
               use_globals=True,
               name="edge_block"):
    """Initializes the EdgeBlock module.

    Args:
      edge_model_fn: A callable that will be called in the variable scope of
        this EdgeBlock and should return a Sonnet module (or equivalent
        callable) to be used as the edge model. The returned module should take
        a `Tensor` (of concatenated input features for each edge) and return a
        `Tensor` (of output features for each edge). Typically, this module
        would input and output `Tensor`s of rank 2, but it may also input or
        output larger ranks. See the `_build` method documentation for more
        details on the acceptable inputs to this module in that case.
      use_edges: (bool, default=True). Whether to condition on edge attributes.
      use_receiver_nodes: (bool, default=True). Whether to condition on receiver
        node attributes.
      use_sender_nodes: (bool, default=True). Whether to condition on sender
        node attributes.
      use_globals: (bool, default=True). Whether to condition on global
        attributes.
      name: The module name.

    Raises:
      ValueError: When fields that are required are missing.
    """
    super(EdgeBlock, self).__init__(name=name)

    if not (use_edges or use_sender_nodes or use_receiver_nodes or use_globals):
      raise ValueError("At least one of use_edges, use_sender_nodes, "
                       "use_receiver_nodes or use_globals must be True.")
    # store the configuration flags
    self._use_edges = use_edges
    self._use_receiver_nodes = use_receiver_nodes
    self._use_sender_nodes = use_sender_nodes
    self._use_globals = use_globals

    with self._enter_variable_scope():
      self._edge_model = edge_model_fn()

  def _build(self, graph):
    """Connects the edge block.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s, whose individual edges
        features (if `use_edges` is `True`), individual nodes features (if
        `use_receiver_nodes` or `use_sender_nodes` is `True`) and per graph
        globals (if `use_globals` is `True`) should be concatenable on the last
        axis.

    Returns:
      An output `graphs.GraphsTuple` with updated edges.

    Raises:
      ValueError: If `graph` does not have non-`None` receivers and senders, or
        if `graph` has `None` fields incompatible with the selected `use_edges`,
        `use_receiver_nodes`, `use_sender_nodes`, or `use_globals` options.
    """
    # check that the graph has the fields this block depends on
    _validate_graph(
        graph, (SENDERS, RECEIVERS, N_EDGE), " when using an EdgeBlock")

    edges_to_collect = []

    if self._use_edges:
      _validate_graph(graph, (EDGES,), "when use_edges == True")
      edges_to_collect.append(graph.edges) # collect the existing edge features

    if self._use_receiver_nodes:
      edges_to_collect.append(broadcast_receiver_nodes_to_edges(graph)) # for each edge, the features of its receiver node

    if self._use_sender_nodes:
      edges_to_collect.append(broadcast_sender_nodes_to_edges(graph)) # for each edge, the features of its sender node

    if self._use_globals:
      edges_to_collect.append(broadcast_globals_to_edges(graph)) # the per-graph globals, broadcast to the edges

    collected_edges = tf.concat(edges_to_collect, axis=-1) # the leading dimension is the total number of edges, so concatenate along the last axis
    updated_edges = self._edge_model(collected_edges) # apply the edge model to compute the new edge features
    return graph.replace(edges=updated_edges) # return a graph with the updated edge features
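
A minimal usage sketch (annotator's example; assumes Sonnet is installed and reuses the toy_graphs batch from above):

  import sonnet as snt
  from graph_nets import blocks

  edge_block = blocks.EdgeBlock(edge_model_fn=lambda: snt.nets.MLP([32]))
  out = edge_block(toy_graphs)  # out.edges has shape [n_edges, 32]; nodes and globals are unchanged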


class NodeBlock(_base.AbstractModule): # node update block
  """Node block.

  A block that updates the features of each node in a batch of graphs based on
  (a subset of) the previous node features, the aggregated features of the
  adjacent edges, and the global features of the corresponding graph.

  See https://arxiv.org/abs/1806.01261 for more details.
  """

  def __init__(self,
               node_model_fn, # a callable returning a Sonnet module
               use_received_edges=True, # configuration flags
               use_sent_edges=False,
               use_nodes=True,
               use_globals=True,
               received_edges_reducer=tf.math.unsorted_segment_sum, # sums entries that share a segment index
               sent_edges_reducer=tf.math.unsorted_segment_sum,
               name="node_block"):
    """Initializes the NodeBlock module.

    Args:
      node_model_fn: A callable that will be called in the variable scope of
        this NodeBlock and should return a Sonnet module (or equivalent
        callable) to be used as the node model. The returned module should take
        a `Tensor` (of concatenated input features for each node) and return a
        `Tensor` (of output features for each node). Typically, this module
        would input and output `Tensor`s of rank 2, but it may also input or
        output larger ranks. See the `_build` method documentation for more
        details on the acceptable inputs to this module in that case.
      use_received_edges: (bool, default=True) Whether to condition on
        aggregated edges received by each node.
      use_sent_edges: (bool, default=False) Whether to condition on aggregated
        edges sent by each node.
      use_nodes: (bool, default=True) Whether to condition on node attributes.
      use_globals: (bool, default=True) Whether to condition on global
        attributes.
      received_edges_reducer: Reduction to be used when aggregating received
        edges. This should be a callable whose signature matches
        `tf.math.unsorted_segment_sum`.
      sent_edges_reducer: Reduction to be used when aggregating sent edges.
        This should be a callable whose signature matches
        `tf.math.unsorted_segment_sum`.
      name: The module name.

    Raises:
      ValueError: When fields that are required are missing.
    """

    super(NodeBlock, self).__init__(name=name)

    if not (use_nodes or use_sent_edges or use_received_edges or use_globals):
      raise ValueError("At least one of use_received_edges, use_sent_edges, "
                       "use_nodes or use_globals must be True.")

    # store the configuration flags
    self._use_received_edges = use_received_edges
    self._use_sent_edges = use_sent_edges
    self._use_nodes = use_nodes
    self._use_globals = use_globals

    with self._enter_variable_scope():
      self._node_model = node_model_fn()
      if self._use_received_edges: # aggregate the edges received by each node?
        if received_edges_reducer is None:
          raise ValueError(
              "If `use_received_edges==True`, `received_edges_reducer` "
              "should not be None.")
        self._received_edges_aggregator = ReceivedEdgesToNodesAggregator( # indexed by receiver nodes
            received_edges_reducer)
      if self._use_sent_edges: # aggregate the edges sent by each node?
        if sent_edges_reducer is None:
          raise ValueError(
              "If `use_sent_edges==True`, `sent_edges_reducer` "
              "should not be None.")
        self._sent_edges_aggregator = SentEdgesToNodesAggregator( # indexed by sender nodes
            sent_edges_reducer)

  def _build(self, graph):
    """Connects the node block.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s, whose individual edges
        features (if `use_received_edges` or `use_sent_edges` is `True`),
        individual nodes features (if `use_nodes` is True) and per graph globals
        (if `use_globals` is `True`) should be concatenable on the last axis.

    Returns:
      An output `graphs.GraphsTuple` with updated nodes.
    """

    nodes_to_collect = []

    if self._use_received_edges: # aggregated features of the edges received by each node
      nodes_to_collect.append(self._received_edges_aggregator(graph))

    if self._use_sent_edges: # aggregated features of the edges sent by each node
      nodes_to_collect.append(self._sent_edges_aggregator(graph))

    if self._use_nodes: # the existing node features
      _validate_graph(graph, (NODES,), "when use_nodes == True")
      nodes_to_collect.append(graph.nodes)

    if self._use_globals: # the per-graph globals, broadcast to the nodes
      nodes_to_collect.append(broadcast_globals_to_nodes(graph))

    collected_nodes = tf.concat(nodes_to_collect, axis=-1) # the leading dimension is the total number of nodes, so concatenate along the last axis
    updated_nodes = self._node_model(collected_nodes) # apply the node model to compute the new node features
    return graph.replace(nodes=updated_nodes) # return a graph with the updated node features


class GlobalBlock(_base.AbstractModule): # global update block
  """Global block.

  A block that updates the global features of each graph in a batch based on
  (a subset of) the previous global features, the aggregated features of the
  edges of the graph, and the aggregated features of the nodes of the graph.

  See https://arxiv.org/abs/1806.01261 for more details.
  """

  def __init__(self,
               global_model_fn, # a callable returning a Sonnet module
               use_edges=True, # configuration flags (typically passed by unpacking an options dict with **)
               use_nodes=True,
               use_globals=True,
               nodes_reducer=tf.math.unsorted_segment_sum,
               edges_reducer=tf.math.unsorted_segment_sum,
               name="global_block"):
    """Initializes the GlobalBlock module.

    Args:
      global_model_fn: A callable that will be called in the variable scope of
        this GlobalBlock and should return a Sonnet module (or equivalent
        callable) to be used as the global model. The returned module should
        take a `Tensor` (of concatenated input features) and return a `Tensor`
        (the global output features). Typically, this module would input and
        output `Tensor`s of rank 2, but it may also input or output larger
        ranks. See the `_build` method documentation for more details on the
        acceptable inputs to this module in that case.
      use_edges: (bool, default=True) Whether to condition on aggregated edges.
      use_nodes: (bool, default=True) Whether to condition on node attributes.
      use_globals: (bool, default=True) Whether to condition on global
        attributes.
      nodes_reducer: Reduction to be used when aggregating nodes. This should
        be a callable whose signature matches tf.math.unsorted_segment_sum.
      edges_reducer: Reduction to be used when aggregating edges. This should
        be a callable whose signature matches tf.math.unsorted_segment_sum.
      name: The module name.

    Raises:
      ValueError: When fields that are required are missing.
    """

    super(GlobalBlock, self).__init__(name=name)

    if not (use_nodes or use_edges or use_globals):
      raise ValueError("At least one of use_edges, "
                       "use_nodes or use_globals must be True.")

    # store the configuration flags
    self._use_edges = use_edges
    self._use_nodes = use_nodes
    self._use_globals = use_globals

    with self._enter_variable_scope():
      self._global_model = global_model_fn() # the global model
      if self._use_edges: # aggregate the edge features of each graph?
        if edges_reducer is None:
          raise ValueError(
              "If `use_edges==True`, `edges_reducer` should not be None.")
        self._edges_aggregator = EdgesToGlobalsAggregator( # aggregates all edges of each graph
            edges_reducer)
      if self._use_nodes: # aggregate the node features of each graph?
        if nodes_reducer is None:
          raise ValueError(
              "If `use_nodes==True`, `nodes_reducer` should not be None.")
        self._nodes_aggregator = NodesToGlobalsAggregator( # aggregates all nodes of each graph
            nodes_reducer)

  def _build(self, graph):
    """Connects the global block.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s, whose individual edges
        (if `use_edges` is `True`), individual nodes (if `use_nodes` is True)
        and per graph globals (if `use_globals` is `True`) should be
        concatenable on the last axis.

    Returns:
      An output `graphs.GraphsTuple` with updated globals.
    """
    globals_to_collect = []

    if self._use_edges: # aggregated edge features
      _validate_graph(graph, (EDGES,), "when use_edges == True")
      globals_to_collect.append(self._edges_aggregator(graph))

    if self._use_nodes: # aggregated node features
      _validate_graph(graph, (NODES,), "when use_nodes == True")
      globals_to_collect.append(self._nodes_aggregator(graph))

    if self._use_globals: # the previous global features
      _validate_graph(graph, (GLOBALS,), "when use_globals == True")
      globals_to_collect.append(graph.globals)

    collected_globals = tf.concat(globals_to_collect, axis=-1) # the leading dimension is the number of graphs, so concatenate along the last axis
    updated_globals = self._global_model(collected_globals) # apply the global model to compute the new globals
    return graph.replace(globals=updated_globals) # return a graph with the updated globals

(2)graphs.py

# Copyright 2018 The GraphNets Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""A class that defines graph-structured data.

The main purpose of the `GraphsTuple` is to represent multiple graphs with
different shapes and sizes in a way that supports batched processing.

This module first defines the string constants which are used to represent
graph(s) as tuples or dictionaries: `N_NODE, N_EDGE, NODES, EDGES, RECEIVERS,
SENDERS, GLOBALS`.

This representation could typically take the following form, for a batch of
`n_graphs` graphs stored in a `GraphsTuple` called graph:

  - N_NODE: The number of nodes per graph. It is a vector of integers with shape
    `[n_graphs]`, such that `graph.N_NODE[i]` is the number of nodes in the i-th
    graph.

  - N_EDGE: The number of edges per graph. It is a vector of integers with shape
    `[n_graphs]`, such that `graph.N_EDGE[i]` is the number of edges in the i-th
    graph.

  - NODES: The nodes features. It is either `None` (the graph has no node
    features), or a vector of shape `[n_nodes] + node_shape`, where
    `n_nodes = sum(graph.N_NODE)` is the total number of nodes in the batch of
    graphs, and `node_shape` represents the shape of the features of each node.
    The relative index of a node from the batched version can be recovered from
    the `graph.N_NODE` property. For instance, the second node of the third
    graph will have its features in the
    `1 + graph.N_NODE[0] + graph.N_NODE[1]`-th slot of graph.NODES.
    Observe that having a `None` value for this field does not mean that the
    graphs have no nodes, only that they do not have features.

  - EDGES: The edges features. It is either `None` (the graph has no edge
    features), or a vector of shape `[n_edges] + edge_shape`, where
    `n_edges = sum(graph.N_EDGE)` is the total number of edges in the batch of
    graphs, and `edge_shape` represents the shape of the features of each edge.
    The relative index of an edge from the batched version can be recovered from
    the `graph.N_EDGE` property. For instance, the third edge of the third
    graph will have its features in the `2 + graph.N_EDGE[0] + graph.N_EDGE[1]`-
    th slot of graph.EDGES.
    Observe that having a `None` value for this field does not necessarily mean
    that the graph has no edges, only that they do not have features.

  - RECEIVERS: The indices of the receiver nodes, for each edge. It is either
    `None` (if the graph has no edges), or a vector of integers of shape
    `[n_edges]`, such that `graph.RECEIVERS[i]` is the index of the node
    receiving from the i-th edge.
    Observe that the index is absolute (in other words, cumulative), i.e.
    `graphs.RECEIVERS` takes values in `[0, n_nodes]`. For instance, an edge
    connecting the vertices with relative indices 2 and 3 in the second graph of
    the batch would have a `RECEIVERS` value of `3 + graph.N_NODE[0]`.
    If `graphs.RECEIVERS` is `None`, then `graphs.EDGES` and `graphs.SENDERS`
    should also be `None`.

  - SENDERS: The indices of the sender nodes, for each edge. It is either
    `None` (if the graph has no edges), or a vector of integers of shape
    `[n_edges]`, such that `graph.SENDERS[i]` is the index of the node
    sending from the i-th edge.
    Observe that the index is absolute, i.e. `graphs.SENDERS` takes values in
    `[0, n_nodes]`. For instance, an edge connecting the vertices with relative
    indices 1 and 3 in the third graph of the batch would have a `SENDERS` value
    of `1 + graph.N_NODE[0] + graph.N_NODE[1]`.
    If `graphs.SENDERS` is `None`, then `graphs.EDGES` and `graphs.RECEIVERS`
    should also be `None`.

  - GLOBALS: The global features of the graph. It is either `None` (the graph
    has no global features), or a vector of shape `[n_graphs] + global_shape`
    representing graph level features.

The `utils_np` and `utils_tf` modules provide convenience methods to work with
graphs that contain numpy and tensorflow data, respectively: conversion,
batching, unbatching, indexing, among others.

The `GraphsTuple` class, however, is not restricted to storing vectors, and can
be used to store attributes of graphs as well (for instance, types or shapes).

The only assertions it makes are that the `None` fields are compatible with the
definition of a graph given above, namely:

  - the N_NODE and N_EDGE fields cannot be `None`;

  - if RECEIVERS is None, then SENDERS must be `None` (and vice-versa);

  - if RECEIVERS and SENDERS are `None`, then `EDGES` must be `None`.

Those assumptions are checked both upon initialization and when replacing a
field by calling the `replace` or `map` method.
"""

"""
GraphsTuple的主要目的是以支持批处理的方式表示具有不同形状和大小的多个图

该模块首先定义用于将图表示为元组或字典的字符串常量:
N_NODE,N_EDGE,NODES,EDGES,RECEIVERS,SENDERS,GLOBALS

N_NODE:每个图的节点数向量,它是形状为[n_graphs],因此graph.N_NODE[i]是第i个图中的节点数

N_EDGE:每个图的边数向量,它是形状为[n_graphs],因此graph.N_EDGE[i]是第i个图的边数

NODES:多维向量,第一维向量表示节点总数,可用N_NODE切分各图,剩余维度表示节点特征
例如,第三个图的第二个节点将在graph.NODES的第1+graph.N_NODE[0]+graph.N_NODE[1]个序号中具有其特征

EDGES:多维向量,第一维向量表示边总数,可用N_EDGE切分各图,剩余维度表示边特征
例如,第三个图的第三个边将在graph.EDGES的第2+graph.N_NODE[0]+graph.N_NODE[1]个序号中具有其特征

RECEIVERS:一维向量,长度为N_EDGE的长度,表示的是第i条边的接受者,如 2->3 其值就为 3

SENDERS:一维向量,长度为N_EDGE的长度,表示的是第i条边的发出者,如 2->3 其值就为 2

GLOBALS:全局特征变量,第一个维度表示图的数目,剩余维度表示其特征,可用 n_graphs 切分访问

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

NODES = "nodes"
EDGES = "edges"
RECEIVERS = "receivers"
SENDERS = "senders"
GLOBALS = "globals"
N_NODE = "n_node"
N_EDGE = "n_edge"

GRAPH_FEATURE_FIELDS = (NODES, EDGES, GLOBALS)
GRAPH_INDEX_FIELDS = (RECEIVERS, SENDERS)
GRAPH_DATA_FIELDS = (NODES, EDGES, RECEIVERS, SENDERS, GLOBALS)
GRAPH_NUMBER_FIELDS = (N_NODE, N_EDGE)
ALL_FIELDS = (NODES, EDGES, RECEIVERS, SENDERS, GLOBALS, N_NODE, N_EDGE)


class GraphsTuple(
    collections.namedtuple("GraphsTuple",
                           GRAPH_DATA_FIELDS + GRAPH_NUMBER_FIELDS)): # defines the namedtuple subclass and its field (key) names
  """Default namedtuple describing `Graphs`s.

  A child of `collections.namedtuple`s, which allows it to be directly input
  and output from `tensorflow.Session.run()` calls.

  An instance of this class can be constructed as
  ```
  GraphsTuple(nodes=nodes,
              edges=edges,
              globals=globals,
              receivers=receivers,
              senders=senders,
              n_node=n_node,
              n_edge=n_edge)
  ```
  where `nodes`, `edges`, `globals`, `receivers`, `senders`, `n_node` and
  `n_edge` are arbitrary, but are typically numpy arrays, tensors, or `None`;
  see module's documentation for a more detailed description of which fields
  can be left `None`.
  """

  def _validate_none_fields(self):
    """Asserts that the set of `None` fields in the instance is valid."""
    if self.n_node is None:
      raise ValueError("Field `n_node` cannot be None")
    if self.n_edge is None:
      raise ValueError("Field `n_edge` cannot be None")
    if self.receivers is None and self.senders is not None:
      raise ValueError(
          "Field `senders` must be None as field `receivers` is None")
    if self.senders is None and self.receivers is not None:
      raise ValueError(
          "Field `receivers` must be None as field `senders` is None")
    if self.receivers is None and self.edges is not None:
      raise ValueError(
          "Field `edges` must be None as field `receivers` and `senders` are "
          "None")

  def __init__(self, *args, **kwargs):
    del args, kwargs
    # The fields of a `namedtuple` are filled in the `__new__` method.
    # `__init__` does not accept parameters.
    super(GraphsTuple, self).__init__()
    self._validate_none_fields()

  def replace(self, **kwargs):
    output = self._replace(**kwargs) # namedtuple's `_replace` returns a new instance
    output._validate_none_fields()  # pylint: disable=protected-access (check the new instance still has a valid set of `None` fields)
    return output

  def map(self, field_fn, fields=GRAPH_FEATURE_FIELDS): # apply a function to the selected fields
    """Applies `field_fn` to the fields `fields` of the instance.

    `field_fn` is applied exactly once per field in `fields`. The result must
    satisfy the `GraphsTuple` requirement w.r.t. `None` fields, i.e. the
    `SENDERS` cannot be `None` if the `EDGES` or `RECEIVERS` are not `None`,
    etc.

    Args:
      field_fn: A callable that take a single argument.
      fields: (iterable of `str`). An iterable of the fields to apply
        `field_fn` to.

    Returns:
      A copy of the instance, with the fields in `fields` replaced by the result
      of applying `field_fn` to them.
    """
    return self.replace(**{k: field_fn(getattr(self, k)) for k in fields}) # getattr(self, k) looks up the value of field k
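
A small concrete batch may make the field layout clearer (annotator's sketch; the shapes are illustrative):

  import numpy as np
  from graph_nets import graphs

  g = graphs.GraphsTuple(
      nodes=np.zeros((4, 5), dtype=np.float32),   # 3 + 1 nodes, 5 features each
      edges=np.zeros((2, 7), dtype=np.float32),   # 2 + 0 edges, 7 features each
      receivers=np.array([1, 2]),                 # absolute (batch-wide) node indices
      senders=np.array([0, 1]),
      globals=np.zeros((2, 3), dtype=np.float32), # one 3-feature vector per graph
      n_node=np.array([3, 1]),
      n_edge=np.array([2, 0]))
  g2 = g.map(lambda f: f * 2.0, fields=("nodes", "edges"))  # a new GraphsTuple; g is unchanged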

(3)modules.py

# Copyright 2018 The GraphNets Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Common Graph Network architectures.

The modules in this files are Sonnet modules that:

  - take a `graphs.GraphsTuple` containing `Tensor`s as input, with possibly
    `None` fields (depending on the module);

  - return a `graphs.GraphsTuple` with updated values for some fields
    (depending on the module).


The provided modules are:

  - `GraphNetwork`: a general purpose Graph Network composed of configurable
    `EdgeBlock`, `NodeBlock` and `GlobalBlock` from `blocks.py`;

  - `GraphIndependent`: a Graph Network producing updated edges (resp. nodes,
    globals) based on the input's edges (resp. nodes, globals) only;

  - `InteractionNetwork` (from https://arxiv.org/abs/1612.00222): a
    network propagating information on the edges and nodes of a graph;

  - RelationNetwork (from https://arxiv.org/abs/1706.01427): a network
    updating the global property based on the relation between the input's
    nodes properties;

  - DeepSets (from https://arxiv.org/abs/1703.06114): a network that operates on
    sets (graphs without edges);

  - CommNet (from https://arxiv.org/abs/1605.07736 and
    https://arxiv.org/abs/1706.06122): a network updating nodes based on their
    previous features and the features of the adjacent nodes.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from graph_nets import _base
from graph_nets import blocks
import tensorflow as tf

_DEFAULT_EDGE_BLOCK_OPT = { # default EdgeBlock options
    "use_edges": True,
    "use_receiver_nodes": True,
    "use_sender_nodes": True,
    "use_globals": True,
}

_DEFAULT_NODE_BLOCK_OPT = { # default NodeBlock options
    "use_received_edges": True,
    "use_sent_edges": False,
    "use_nodes": True,
    "use_globals": True,
}

_DEFAULT_GLOBAL_BLOCK_OPT = { # default GlobalBlock options
    "use_edges": True,
    "use_nodes": True,
    "use_globals": True,
}


class InteractionNetwork(_base.AbstractModule):
  """Implementation of an Interaction Network.

  An interaction network computes interactions on the edges based on the
  previous edge features, and on the features of the nodes sending into those
  edges. It then updates the nodes based on the incoming updated edges.
  See https://arxiv.org/abs/1612.00222 for more details.

  This model does not update the graph globals, and they are allowed to be
  `None`.
  """

  def __init__(self,
               edge_model_fn,
               node_model_fn,
               reducer=tf.math.unsorted_segment_sum,
               name="interaction_network"):
    """Initializes the InteractionNetwork module.

    Args:
      edge_model_fn: A callable that will be passed to `EdgeBlock` to perform
        per-edge computations. The callable must return a Sonnet module (or
        equivalent; see `blocks.EdgeBlock` for details), and the shape of the
        output of this module must match the one of the input nodes, but for the
        first and last axis.
      node_model_fn: A callable that will be passed to `NodeBlock` to perform
        per-node computations. The callable must return a Sonnet module (or
        equivalent; see `blocks.NodeBlock` for details).
      reducer: Reducer to be used by NodeBlock to aggregate edges. Defaults to
        tf.math.unsorted_segment_sum.
      name: The module name.
    """
    super(InteractionNetwork, self).__init__(name=name)

    with self._enter_variable_scope():
      self._edge_block = blocks.EdgeBlock(
          edge_model_fn=edge_model_fn, use_globals=False)
      self._node_block = blocks.NodeBlock(
          node_model_fn=node_model_fn,
          use_sent_edges=False,
          use_globals=False,
          received_edges_reducer=reducer)

  def _build(self, graph):
    """Connects the InterationNetwork.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s. `graph.globals` can be
        `None`. The features of each node and edge of `graph` must be
        concatenable on the last axis (i.e., the shapes of `graph.nodes` and
        `graph.edges` must match but for their first and last axis).

    Returns:
      An output `graphs.GraphsTuple` with updated edges and nodes.

    Raises:
      ValueError: If any of `graph.nodes`, `graph.edges`, `graph.receivers` or
        `graph.senders` is `None`.
    """
    return self._node_block(self._edge_block(graph))


class RelationNetwork(_base.AbstractModule):
  """Implementation of a Relation Network.

  See https://arxiv.org/abs/1706.01427 for more details.

  The global and edges features of the input graph are not used, and are
  allowed to be `None` (the receivers and senders properties must be present).
  The output graph has updated, non-`None`, globals.
  """

  def __init__(self,
               edge_model_fn,
               global_model_fn,
               reducer=tf.math.unsorted_segment_sum,
               name="relation_network"):
    """Initializes the RelationNetwork module.

    Args:
      edge_model_fn: A callable that will be passed to EdgeBlock to perform
        per-edge computations. The callable must return a Sonnet module (or
        equivalent; see EdgeBlock for details).
      global_model_fn: A callable that will be passed to GlobalBlock to perform
        per-global computations. The callable must return a Sonnet module (or
        equivalent; see GlobalBlock for details).
      reducer: Reducer to be used by GlobalBlock to aggregate edges. Defaults to
        tf.math.unsorted_segment_sum.
      name: The module name.
    """
    super(RelationNetwork, self).__init__(name=name)

    with self._enter_variable_scope():
      self._edge_block = blocks.EdgeBlock(
          edge_model_fn=edge_model_fn,
          use_edges=False,
          use_receiver_nodes=True,
          use_sender_nodes=True,
          use_globals=False)

      self._global_block = blocks.GlobalBlock(
          global_model_fn=global_model_fn,
          use_edges=True,
          use_nodes=False,
          use_globals=False,
          edges_reducer=reducer)

  def _build(self, graph):
    """Connects the RelationNetwork.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s, except for the edges
        and global properties which may be `None`.

    Returns:
      A `graphs.GraphsTuple` with updated globals.

    Raises:
      ValueError: If any of `graph.nodes`, `graph.receivers` or `graph.senders`
        is `None`.
    """
    output_graph = self._global_block(self._edge_block(graph))
    return graph.replace(globals=output_graph.globals)


def _make_default_edge_block_opt(edge_block_opt): # options controlling which inputs the EdgeBlock uses
  """Default options to be used in the EdgeBlock of a generic GraphNetwork."""
  edge_block_opt = dict(edge_block_opt.items()) if edge_block_opt else {} # start from an empty dict if no options were given
  for k, v in _DEFAULT_EDGE_BLOCK_OPT.items(): # for every default option
    edge_block_opt[k] = edge_block_opt.get(k, v) # keep the user's value if present, otherwise fall back to the default
  return edge_block_opt


def _make_default_node_block_opt(node_block_opt, default_reducer): # options controlling which inputs the NodeBlock uses
  """Default options to be used in the NodeBlock of a generic GraphNetwork."""
  node_block_opt = dict(node_block_opt.items()) if node_block_opt else {} # start from an empty dict if no options were given
  for k, v in _DEFAULT_NODE_BLOCK_OPT.items(): # for every default option
    node_block_opt[k] = node_block_opt.get(k, v) # keep the user's value if present, otherwise fall back to the default
  for key in ["received_edges_reducer", "sent_edges_reducer"]: # the reducers also need defaults
    node_block_opt[key] = node_block_opt.get(key, default_reducer) # use `default_reducer` when no reducer was specified
  return node_block_opt


def _make_default_global_block_opt(global_block_opt, default_reducer): # options controlling which inputs the GlobalBlock uses
  """Default options to be used in the GlobalBlock of a generic GraphNetwork."""
  global_block_opt = dict(global_block_opt.items()) if global_block_opt else {} # start from an empty dict if no options were given
  for k, v in _DEFAULT_GLOBAL_BLOCK_OPT.items(): # for every default option
    global_block_opt[k] = global_block_opt.get(k, v) # keep the user's value if present, otherwise fall back to the default
  for key in ["edges_reducer", "nodes_reducer"]: # the reducers also need defaults
    global_block_opt[key] = global_block_opt.get(key, default_reducer) # use `default_reducer` when no reducer was specified
  return global_block_opt
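
For example (annotator's sketch), user-supplied keys override the defaults and missing keys fall back to them:

  opts = _make_default_edge_block_opt({"use_globals": False})
  # opts == {"use_edges": True, "use_receiver_nodes": True,
  #          "use_sender_nodes": True, "use_globals": False}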


class GraphNetwork(_base.AbstractModule): # uses edge, node and global information
  """Implementation of a Graph Network.

  See https://arxiv.org/abs/1806.01261 for more details.
  """

  def __init__(self,
               edge_model_fn, # edge model
               node_model_fn, # node model
               global_model_fn, # global model
               reducer=tf.math.unsorted_segment_sum, # default aggregation function
               edge_block_opt=None, # which inputs the EdgeBlock uses
               node_block_opt=None, # which inputs the NodeBlock uses
               global_block_opt=None, # which inputs the GlobalBlock uses
               name="graph_network"):
    """Initializes the GraphNetwork module.

    Args:
      edge_model_fn: A callable that will be passed to EdgeBlock to perform
        per-edge computations. The callable must return a Sonnet module (or
        equivalent; see EdgeBlock for details).
      node_model_fn: A callable that will be passed to NodeBlock to perform
        per-node computations. The callable must return a Sonnet module (or
        equivalent; see NodeBlock for details).
      global_model_fn: A callable that will be passed to GlobalBlock to perform
        per-global computations. The callable must return a Sonnet module (or
        equivalent; see GlobalBlock for details).
      reducer: Reducer to be used by NodeBlock and GlobalBlock to aggregate
        nodes and edges. Defaults to tf.math.unsorted_segment_sum. This will be
        overridden by the reducers specified in `node_block_opt` and
        `global_block_opt`, if any.
      edge_block_opt: Additional options to be passed to the EdgeBlock. Can
        contain keys `use_edges`, `use_receiver_nodes`, `use_sender_nodes`,
        `use_globals`. By default, these are all True.
      node_block_opt: Additional options to be passed to the NodeBlock. Can
        contain the keys `use_received_edges`, `use_nodes`, `use_globals` (all
        set to True by default), `use_sent_edges` (defaults to False), and
        `received_edges_reducer`, `sent_edges_reducer` (default to `reducer`).
      global_block_opt: Additional options to be passed to the GlobalBlock. Can
        contain the keys `use_edges`, `use_nodes`, `use_globals` (all set to
        True by default), and `edges_reducer`, `nodes_reducer` (defaults to
        `reducer`).
      name: The module name.
    """
    super(GraphNetwork, self).__init__(name=name)
    # fill in the options for each block
    edge_block_opt = _make_default_edge_block_opt(edge_block_opt)
    node_block_opt = _make_default_node_block_opt(node_block_opt, reducer)
    global_block_opt = _make_default_global_block_opt(global_block_opt, reducer)

    # build the edge, node and global blocks
    with self._enter_variable_scope(): # `**` unpacks each options dict into keyword arguments
      self._edge_block = blocks.EdgeBlock( # build the edge block
          edge_model_fn=edge_model_fn, **edge_block_opt)
      self._node_block = blocks.NodeBlock( # build the node block
          node_model_fn=node_model_fn, **node_block_opt)
      self._global_block = blocks.GlobalBlock( # build the global block
          global_model_fn=global_model_fn, **global_block_opt)

  def _build(self, graph):
    """Connects the GraphNetwork.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s. Depending on the block
        options, `graph` may contain `None` fields; but with the default
        configuration, no `None` field is allowed. Moreover, when using the
        default configuration, the features of each nodes, edges and globals of
        `graph` should be concatenable on the last dimension.

    Returns:
      An output `graphs.GraphsTuple` with updated edges, nodes and globals.
    """
    return self._global_block(self._node_block(self._edge_block(graph)))
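
A minimal end-to-end sketch (annotator's example; assumes Sonnet is installed and reuses the toy_graphs batch defined earlier):

  import sonnet as snt
  from graph_nets import modules

  graph_net = modules.GraphNetwork(
      edge_model_fn=lambda: snt.nets.MLP([16]),
      node_model_fn=lambda: snt.nets.MLP([16]),
      global_model_fn=lambda: snt.nets.MLP([16]))
  output_graphs = graph_net(toy_graphs)  # edges, nodes and globals now all have 16 features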


class GraphIndependent(_base.AbstractModule):
  """A graph block that applies models to the graph elements independently.

  The inputs and outputs are graphs. The corresponding models are applied to
  each element of the graph (edges, nodes and globals) in parallel and
  independently of the other elements. It can be used to encode or
  decode the elements of a graph.
  """

  def __init__(self,
               edge_model_fn=None,
               node_model_fn=None,
               global_model_fn=None,
               name="graph_independent"):
    """Initializes the GraphIndependent module.

    Args:
      edge_model_fn: A callable that returns an edge model function. The
        callable must return a Sonnet module (or equivalent). If passed `None`,
        will pass through inputs (the default).
      node_model_fn: A callable that returns a node model function. The callable
        must return a Sonnet module (or equivalent). If passed `None`, will pass
        through inputs (the default).
      global_model_fn: A callable that returns a global model function. The
        callable must return a Sonnet module (or equivalent). If passed `None`,
        will pass through inputs (the default).
      name: The module name.
    """
    super(GraphIndependent, self).__init__(name=name)

    with self._enter_variable_scope():
      # The use of snt.Module below is to ensure the ops and variables that
      # result from the edge/node/global_model_fns are scoped analogous to how
      # the Edge/Node/GlobalBlock classes do.
      if edge_model_fn is None:
        self._edge_model = lambda x: x
      else:
        self._edge_model = _base.WrappedModelFnModule(
            edge_model_fn, name="edge_model")
      if node_model_fn is None:
        self._node_model = lambda x: x
      else:
        self._node_model = _base.WrappedModelFnModule(
            node_model_fn, name="node_model")
      if global_model_fn is None:
        self._global_model = lambda x: x
      else:
        self._global_model = _base.WrappedModelFnModule(
            global_model_fn, name="global_model")

  def _build(self, graph):
    """Connects the GraphIndependent.

    Args:
      graph: A `graphs.GraphsTuple` containing non-`None` edges, nodes and
        globals.

    Returns:
      An output `graphs.GraphsTuple` with updated edges, nodes and globals.

    """
    return graph.replace(
        edges=self._edge_model(graph.edges),
        nodes=self._node_model(graph.nodes),
        globals=self._global_model(graph.globals))


class DeepSets(_base.AbstractModule):
  """DeepSets module.

  Implementation for the model described in https://arxiv.org/abs/1703.06114
  (M. Zaheer, S. Kottur, S. Ravanbakhsh, B. Poczos, R. Salakhutdinov, A. Smola).
  See also PointNet (https://arxiv.org/abs/1612.00593, C. Qi, H. Su, K. Mo,
  L. J. Guibas) for a related model.

  This module operates on sets, which can be thought of as graphs without
  edges. The nodes features are first updated based on their value and the
  globals features, and new globals features are then computed based on the
  updated nodes features.

  Note that in the original model, only the globals are updated in the returned
  graph, while this implementation also returns updated nodes.
  The original model can be reproduced by writing:
  ```
  deep_sets = DeepSets()
  output = deep_sets(input)
  output = input.replace(globals=output.globals)
  ```

  This module does not use the edges data or the information contained in the
  receivers or senders; the output graph has the same value in those fields as
  the input graph. Those fields can also have `None` values in the input
  `graphs.GraphsTuple`.
  """

  def __init__(self,
               node_model_fn,
               global_model_fn,
               reducer=tf.math.unsorted_segment_sum,
               name="deep_sets"):
    """Initializes the DeepSets module.

    Args:
      node_model_fn: A callable to be passed to NodeBlock. The callable must
        return a Sonnet module (or equivalent; see NodeBlock for details). The
        shape of this module's output must equal the shape of the input graph's
        global features, but for the first and last axis.
      global_model_fn: A callable to be passed to GlobalBlock. The callable must
        return a Sonnet module (or equivalent; see GlobalBlock for details).
      reducer: Reduction to be used when aggregating the nodes in the globals.
        This should be a callable whose signature matches
        tf.math.unsorted_segment_sum.
      name: The module name.
    """
    super(DeepSets, self).__init__(name=name)

    with self._enter_variable_scope():
      self._node_block = blocks.NodeBlock(
          node_model_fn=node_model_fn,
          use_received_edges=False,
          use_sent_edges=False,
          use_nodes=True,
          use_globals=True)
      self._global_block = blocks.GlobalBlock(
          global_model_fn=global_model_fn,
          use_edges=False,
          use_nodes=True,
          use_globals=False,
          nodes_reducer=reducer)

  def _build(self, graph):
    """Connects the DeepSets network.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s, whose edges, senders
        or receivers properties may be `None`. The features of every node and
        global of `graph` should be concatenable on the last axis (i.e. the
        shapes of `graph.nodes` and `graph.globals` must match but for their
        first and last axis).

    Returns:
      An output `graphs.GraphsTuple` with updated globals.
    """
    return self._global_block(self._node_block(graph))


class CommNet(_base.AbstractModule):
  """CommNet module.

  Implementation for the model originally described in
  https://arxiv.org/abs/1605.07736 (S. Sukhbaatar, A. Szlam, R. Fergus), in the
  version presented in https://arxiv.org/abs/1706.06122 (Y. Hoshen).

  This module internally creates edge features based on the features from the
  nodes sending to that edge, and independently learns an embedding for each
  node. It then uses these edges and nodes features to compute updated node
  features.

  This module does not use the global or the edge features of the input, but
  uses its receivers and senders information. The output graph has the same
  value in edge and global fields as the input graph. The edge and global
  features fields may have a `None` value in the input `gn_graphs.GraphsTuple`.
  """

  def __init__(self,
               edge_model_fn,
               node_encoder_model_fn,
               node_model_fn,
               reducer=tf.math.unsorted_segment_sum,
               name="comm_net"):
    """Initializes the CommNet module.

    Args:
      edge_model_fn: A callable to be passed to EdgeBlock. The callable must
        return a Sonnet module (or equivalent; see EdgeBlock for details).
      node_encoder_model_fn: A callable to be passed to the NodeBlock
        responsible for the first encoding of the nodes. The callable must
        return a Sonnet module (or equivalent; see NodeBlock for details). The
        shape of this module's output should match the shape of the module built
        by `edge_model_fn`, but for the first and last dimension.
      node_model_fn: A callable to be passed to NodeBlock. The callable must
        return a Sonnet module (or equivalent; see NodeBlock for details).
      reducer: Reduction to be used when aggregating the edges in the nodes.
        This should be a callable whose signature matches
        tf.math.unsorted_segment_sum.
      name: The module name.
    """
    super(CommNet, self).__init__(name=name)

    with self._enter_variable_scope():
      # Computes $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122
      self._edge_block = blocks.EdgeBlock(
          edge_model_fn=edge_model_fn,
          use_edges=False,
          use_receiver_nodes=False,
          use_sender_nodes=True,
          use_globals=False)
      # Computes $\Phi(x_i)$ in Eq. (2) of 1706.06122
      self._node_encoder_block = blocks.NodeBlock(
          node_model_fn=node_encoder_model_fn,
          use_received_edges=False,
          use_sent_edges=False,
          use_nodes=True,
          use_globals=False,
          received_edges_reducer=reducer,
          name="node_encoder_block")
      # Computes $\Theta(..)$ in Eq.(2) of 1706.06122
      self._node_block = blocks.NodeBlock(
          node_model_fn=node_model_fn,
          use_received_edges=True,
          use_sent_edges=False,
          use_nodes=True,
          use_globals=False,
          received_edges_reducer=reducer)

  def _build(self, graph):
    """Connects the CommNet network.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s, with non-`None` nodes,
        receivers and senders.

    Returns:
      An output `graphs.GraphsTuple` with updated nodes.

    Raises:
      ValueError: if any of `graph.nodes`, `graph.receivers` or `graph.senders`
      is `None`.
    """
    node_input = self._node_encoder_block(self._edge_block(graph))
    return graph.replace(nodes=self._node_block(node_input).nodes)
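
A hedged usage sketch for CommNet (not part of the library source; sizes and the toy graph are illustrative). The node encoder and the edge model below share the same output width, as required above; nodes, senders and receivers must be set, while edge and global features pass through unchanged:

import sonnet as snt
from graph_nets import modules, utils_tf

comm_net = modules.CommNet(
    edge_model_fn=lambda: snt.nets.MLP([8]),
    node_encoder_model_fn=lambda: snt.nets.MLP([8]),
    node_model_fn=lambda: snt.nets.MLP([8]))

# Toy chain graph 0 -> 1 -> 2.
graph = utils_tf.data_dicts_to_graphs_tuple([{
    "nodes": [[1.0], [2.0], [3.0]],
    "edges": [[0.0], [0.0]],
    "senders": [0, 1],
    "receivers": [1, 2],
    "globals": [0.0],
}])
output = comm_net(graph)   # only output.nodes differs from the input graph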


def _unsorted_segment_softmax(data,
                              segment_ids,
                              num_segments,
                              name="unsorted_segment_softmax"):
  """Performs an elementwise softmax operation along segments of a tensor.

  The input parameters are analogous to `tf.math.unsorted_segment_sum`. It
  produces an output of the same shape as the input data, after performing an
  elementwise softmax operation between all of the rows with a common segment id.

  Args:
    data: A tensor with at least one dimension.
    segment_ids: A tensor of indices segmenting `data` across the first
      dimension.
    num_segments: A scalar tensor indicating the number of segments. It should
      be at least `max(segment_ids) + 1`.
    name: A name for the operation (optional).

  Returns:
    A tensor with the same shape as `data` after applying the softmax operation.

  """
  with tf.name_scope(name):
    segment_maxes = tf.math.unsorted_segment_max(data, segment_ids,
                                                 num_segments)
    maxes = tf.gather(segment_maxes, segment_ids)
    # Possibly refactor to `tf.stop_gradient(maxes)` for better performance.
    data -= maxes
    exp_data = tf.exp(data)
    segment_sum_exp_data = tf.math.unsorted_segment_sum(exp_data, segment_ids,
                                                        num_segments)
    sum_exp_data = tf.gather(segment_sum_exp_data, segment_ids)
    return exp_data / sum_exp_data
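
A tiny numeric check of the segment-wise softmax (values chosen purely for illustration):

import tensorflow as tf

data = tf.constant([[1.0], [2.0], [1.0]])
segment_ids = tf.constant([0, 1, 0])   # rows 0 and 2 share segment 0
out = _unsorted_segment_softmax(data, segment_ids, num_segments=2)
# Rows 0 and 2 are normalized together and each becomes 0.5;
# row 1 is alone in segment 1, so it becomes 1.0.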


def _received_edges_normalizer(graph,
                               normalizer,
                               name="received_edges_normalizer"):
  """Performs elementwise normalization for all received edges by a given node.

  Args:
    graph: A graph containing edge information.
    normalizer: A normalizer function following the signature of
      `modules._unsorted_segment_softmax`.
    name: A name for the operation (optional).

  Returns:
    A tensor with the resulting normalized edges.

  """
  with tf.name_scope(name):
    return normalizer(
        data=graph.edges,
        segment_ids=graph.receivers,
        num_segments=tf.reduce_sum(graph.n_node))
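
A hedged example on a hand-built `GraphsTuple` (toy connectivity and logits): two edges are received by node 0 and one by node 1, so the per-edge logits are softmax-normalized within each receiver's incoming edges:

import tensorflow as tf
from graph_nets import graphs

g = graphs.GraphsTuple(
    nodes=tf.zeros([2, 1]),
    edges=tf.constant([[0.0], [0.0], [1.0]]),   # per-edge logits
    senders=tf.constant([1, 1, 0]),
    receivers=tf.constant([0, 0, 1]),
    globals=None,
    n_node=tf.constant([2]),
    n_edge=tf.constant([3]))
weights = _received_edges_normalizer(g, normalizer=_unsorted_segment_softmax)
# The two edges received by node 0 get weight 0.5 each;
# the single edge received by node 1 gets weight 1.0.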


class SelfAttention(_base.AbstractModule):
  """Multi-head self-attention module.

  The module is based on the following three papers:
   * A simple neural network module for relational reasoning (RNs):
       https://arxiv.org/abs/1706.01427
   * Non-local Neural Networks: https://arxiv.org/abs/1711.07971.
   * Attention Is All You Need (AIAYN): https://arxiv.org/abs/1706.03762.

  The inputs to the module are a graph providing the connectivity between nodes,
  together with tensors containing the values, keys and queries for each node.

  The self-attention step consists of updating the node values, with each new
  node value computed in a two-step process:
  - Compute the attention weights between each node and all of its sender
   nodes, by calculating sum(sender_key*receiver_query) and applying a softmax
   over all attention weights received by each node.
  - For each receiver node, compute the new node value as the weighted average
   of the sender nodes' values, according to the attention weights.
  - Nodes with no received edges get an updated value of 0.

  Values, keys and queries contain a "head" axis to compute independent
  self-attention for each of the heads.

  """

  def __init__(self, name="self_attention"):
    """Inits the module.

    Args:
      name: The module name.
    """
    super(SelfAttention, self).__init__(name=name)
    self._normalizer = _unsorted_segment_softmax

  def _build(self, node_values, node_keys, node_queries, attention_graph):
    """Connects the multi-head self-attention module.

    The self-attention is only computed according to the connectivity of the
    input graphs, with receiver nodes attending to sender nodes.

    Args:
      node_values: Tensor containing the values associated to each of the nodes.
        The expected shape is [total_num_nodes, num_heads, key_size].
      node_keys: Tensor containing the key associated to each of the nodes. The
        expected shape is [total_num_nodes, num_heads, key_size].
      node_queries: Tensor containing the query associated to each of the nodes.
        The expected shape is [total_num_nodes, num_heads, query_size]. The
        query size must be equal to the key size.
      attention_graph: Graph containing connectivity information between nodes
        via the senders and receivers fields. Node A will only attend to
        Node B if `attention_graph` contains an edge sent by Node B and
        received by Node A (i.e. receivers attend to their senders).

    Returns:
      An output `graphs.GraphsTuple` with updated nodes containing the
      aggregated attended value for each of the nodes with shape
      [total_num_nodes, num_heads, value_size].

    Raises:
      ValueError: if the input graph does not have edges.
    """

    # Sender nodes put their keys and values in the edges.
    # [total_num_edges, num_heads, query_size]
    sender_keys = blocks.broadcast_sender_nodes_to_edges(
        attention_graph.replace(nodes=node_keys))
    # [total_num_edges, num_heads, value_size]
    sender_values = blocks.broadcast_sender_nodes_to_edges(
        attention_graph.replace(nodes=node_values))

    # Receiver nodes put their queries in the edges.
    # [total_num_edges, num_heads, key_size]
    receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
        attention_graph.replace(nodes=node_queries))

    # Attention weight for each edge.
    # [total_num_edges, num_heads]
    attention_weights_logits = tf.reduce_sum(
        sender_keys * receiver_queries, axis=-1)
    normalized_attention_weights = _received_edges_normalizer(
        attention_graph.replace(edges=attention_weights_logits),
        normalizer=self._normalizer)

    # Attending to sender values according to the weights.
    # [total_num_edges, num_heads, embedding_size]
    attended_edges = sender_values * normalized_attention_weights[..., None]

    # Summing all of the attended values from each node.
    # [total_num_nodes, num_heads, embedding_size]
    received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
        reducer=tf.math.unsorted_segment_sum)
    aggregated_attended_values = received_edges_aggregator(
        attention_graph.replace(edges=attended_edges))

    return attention_graph.replace(nodes=aggregated_attended_values)
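
Finally, a hedged usage sketch of SelfAttention (not part of the library source; the head/size constants and the two-edge toy graph are assumptions). Keys, queries and values are passed as separate tensors, and `attention_graph` supplies connectivity only:

import tensorflow as tf
from graph_nets import graphs, modules

num_nodes, num_heads, key_size, value_size = 3, 2, 4, 4
node_values = tf.random.normal([num_nodes, num_heads, value_size])
node_keys = tf.random.normal([num_nodes, num_heads, key_size])
node_queries = tf.random.normal([num_nodes, num_heads, key_size])

# Edges 0 -> 1 and 1 -> 2: node 1 attends to node 0, node 2 attends to node 1,
# and node 0 (which receives no edges) gets a zero value.
attention_graph = graphs.GraphsTuple(
    nodes=tf.zeros([num_nodes, 1]),   # placeholder; overwritten internally
    edges=tf.zeros([2, 1]),
    senders=tf.constant([0, 1]),
    receivers=tf.constant([1, 2]),
    globals=None,
    n_node=tf.constant([num_nodes]),
    n_edge=tf.constant([2]))

self_attention = modules.SelfAttention()
output = self_attention(node_values, node_keys, node_queries, attention_graph)
# output.nodes has shape [num_nodes, num_heads, value_size]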

 
