(1)blocks.py
# Copyright 2018 The GraphNets Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Building blocks for Graph Networks.
This module contains elementary building blocks of graph networks:
- `broadcast_{field_1}_to_{field_2}` propagates the features from `field_1`
onto the relevant elements of `field_2`;
- `{field_1}To{field_2}Aggregator` propagates and then reduces the features
from `field_1` onto the relevant elements of `field_2`;
- the `EdgeBlock`, `NodeBlock` and `GlobalBlock` are elementary graph networks
that only update the edges (resp. the nodes, the globals) of their input
graph (as described in https://arxiv.org/abs/1806.01261).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from graph_nets import _base
from graph_nets import graphs
from graph_nets import utils_tf
import tensorflow as tf
# Short aliases for the `GraphsTuple` field names used throughout this module.
# (`GLOBALS` used to be assigned twice; the redundant duplicate is removed.)
NODES = graphs.NODES
EDGES = graphs.EDGES
GLOBALS = graphs.GLOBALS
RECEIVERS = graphs.RECEIVERS
SENDERS = graphs.SENDERS
N_NODE = graphs.N_NODE
N_EDGE = graphs.N_EDGE
def _validate_graph(graph, mandatory_fields, additional_message=None):
  """Checks that every mandatory field of `graph` is set.

  Args:
    graph: An object exposing the fields as attributes (typically a
      `graphs.GraphsTuple`).
    mandatory_fields: An iterable of attribute names that must not be `None`.
    additional_message: (string, optional) Context appended to the error
      message, e.g. "when using an EdgeBlock". Surrounding whitespace and
      trailing periods are ignored, so the resulting message never contains
      a doubled space or a doubled final period.

  Raises:
    ValueError: If one of the mandatory fields is `None`. The message names
      the first such field, in iteration order.
  """
  for field in mandatory_fields:
    if getattr(graph, field) is not None:
      continue
    message = "`{}` field cannot be None".format(field)
    if additional_message:
      # The separating space and the final period are added here, so drop
      # whatever the caller already supplied.
      context = format(additional_message).strip().rstrip(".").rstrip()
      if context:
        message += " " + context
    raise ValueError(message + ".")
def _validate_broadcasted_graph(graph, from_field, to_field):
  """Checks that both fields involved in a broadcast are set on `graph`.

  Args:
    graph: The graph to check (typically a `graphs.GraphsTuple`).
    from_field: Name of the field whose values are broadcast.
    to_field: Name of the field the values are broadcast onto.

  Raises:
    ValueError: (from `_validate_graph`) if either field is `None`.
  """
  required_fields = [from_field, to_field]
  context = "when broadcasting {} to {}".format(from_field, to_field)
  _validate_graph(graph, required_fields, context)
def broadcast_globals_to_edges(graph, name="broadcast_globals_to_edges"):
  """Broadcasts the global features to the edges of a graph.

  The globals of each graph are repeated once per edge of that graph, and the
  repeated blocks are laid out along axis 0.

  Args:
    graph: A `graphs.GraphsTuple` containing `Tensor`s, with globals features of
      shape `[n_graphs] + global_shape`, and `N_EDGE` field of shape
      `[n_graphs]`.
    name: (string, optional) A name for the operation.

  Returns:
    A tensor of shape `[n_edges] + global_shape`, where
    `n_edges = sum(graph.n_edge)`. The i-th element of this tensor is given by
    `globals[j]`, where j is the index of the graph the i-th edge belongs to
    (i.e. is such that
    `sum_{k < j} graphs.n_edge[k] <= i < sum_{k <= j} graphs.n_edge[k]`).

  Raises:
    ValueError: If either `graph.globals` or `graph.n_edge` is `None`.
  """
  _validate_broadcasted_graph(graph, GLOBALS, N_EDGE)
  with tf.name_scope(name):  # Groups the created ops under `name`.
    per_graph_globals = graph.globals
    edges_per_graph = graph.n_edge
    return utils_tf.repeat(per_graph_globals, edges_per_graph, axis=0)
def broadcast_globals_to_nodes(graph, name="broadcast_globals_to_nodes"):
  """Broadcasts the global features to the nodes of a graph.

  The globals of each graph are repeated once per node of that graph, and the
  repeated blocks are laid out along axis 0.

  Args:
    graph: A `graphs.GraphsTuple` containing `Tensor`s, with globals features of
      shape `[n_graphs] + global_shape`, and `N_NODE` field of shape
      `[n_graphs]`.
    name: (string, optional) A name for the operation.

  Returns:
    A tensor of shape `[n_nodes] + global_shape`, where
    `n_nodes = sum(graph.n_node)`. The i-th element of this tensor is given by
    `globals[j]`, where j is the index of the graph the i-th node belongs to
    (i.e. is such that
    `sum_{k < j} graphs.n_node[k] <= i < sum_{k <= j} graphs.n_node[k]`).

  Raises:
    ValueError: If either `graph.globals` or `graph.n_node` is `None`.
  """
  _validate_broadcasted_graph(graph, GLOBALS, N_NODE)
  with tf.name_scope(name):  # Groups the created ops under `name`.
    per_graph_globals = graph.globals
    nodes_per_graph = graph.n_node
    return utils_tf.repeat(per_graph_globals, nodes_per_graph, axis=0)
def broadcast_sender_nodes_to_edges(
    graph, name="broadcast_sender_nodes_to_edges"):
  """Broadcasts the node features to the edges they are sending into.

  For every edge, looks up the feature of that edge's sender node.

  Args:
    graph: A `graphs.GraphsTuple` containing `Tensor`s, with nodes features of
      shape `[n_nodes] + node_shape`, and `senders` field of shape
      `[n_edges]`.
    name: (string, optional) A name for the operation.

  Returns:
    A tensor of shape `[n_edges] + node_shape`, where
    `n_edges = sum(graph.n_edge)`. The i-th element is given by
    `graph.nodes[graph.senders[i]]`.

  Raises:
    ValueError: If either `graph.nodes` or `graph.senders` is `None`.
  """
  _validate_broadcasted_graph(graph, NODES, SENDERS)
  with tf.name_scope(name):  # Groups the created ops under `name`.
    sender_indices = graph.senders
    return tf.gather(graph.nodes, sender_indices)
def broadcast_receiver_nodes_to_edges(
    graph, name="broadcast_receiver_nodes_to_edges"):
  """Broadcasts the node features to the edges they are receiving from.

  For every edge, looks up the feature of that edge's receiver node.

  Args:
    graph: A `graphs.GraphsTuple` containing `Tensor`s, with nodes features of
      shape `[n_nodes] + node_shape`, and receivers of shape `[n_edges]`.
    name: (string, optional) A name for the operation.

  Returns:
    A tensor of shape `[n_edges] + node_shape`, where
    `n_edges = sum(graph.n_edge)`. The i-th element is given by
    `graph.nodes[graph.receivers[i]]`.

  Raises:
    ValueError: If either `graph.nodes` or `graph.receivers` is `None`.
  """
  _validate_broadcasted_graph(graph, NODES, RECEIVERS)
  with tf.name_scope(name):  # Groups the created ops under `name`.
    receiver_indices = graph.receivers
    return tf.gather(graph.nodes, receiver_indices)
class EdgesToGlobalsAggregator(_base.AbstractModule):
  """Aggregates all edges into globals.

  The edges of each graph are reduced together, so the leading dimension of
  the output is the number of graphs.
  """

  def __init__(self, reducer, name="edges_to_globals_aggregator"):
    """Initializes the EdgesToGlobalsAggregator module.

    The reducer is used for combining per-edge features (one set of edge
    feature vectors per graph) to give per-graph features (one feature
    vector per graph). The reducer should take a `Tensor` of edge features, a
    `Tensor` of segment indices, and a number of graphs. It should be invariant
    under permutation of edge features within each graph.

    Examples of compatible reducers are:
    * tf.math.unsorted_segment_sum
    * tf.math.unsorted_segment_mean
    * tf.math.unsorted_segment_prod
    * unsorted_segment_min_or_zero
    * unsorted_segment_max_or_zero

    Args:
      reducer: A function for reducing sets of per-edge features to individual
        per-graph features.
      name: The module name.
    """
    super(EdgesToGlobalsAggregator, self).__init__(name=name)
    self._reducer = reducer

  def _build(self, graph):
    """Reduces the edge features of each graph to one feature per graph.

    Args:
      graph: A `graphs.GraphsTuple` whose `edges` field is not `None`; its
        `n_edge` field gives the number of edges of each graph.

    Returns:
      The value returned by the reducer when called with the edge features,
      the graph index of every edge, and the number of graphs.

    Raises:
      ValueError: If `graph.edges` is `None`.
    """
    # No trailing period here: `_validate_graph` terminates the sentence.
    _validate_graph(graph, (EDGES,),
                    additional_message="when aggregating from edges")
    num_graphs = utils_tf.get_num_graphs(graph)
    graph_index = tf.range(num_graphs)
    # Repeat each graph index once per edge of that graph, giving the segment
    # id of every edge.
    indices = utils_tf.repeat(graph_index, graph.n_edge, axis=0)
    return self._reducer(graph.edges, indices, num_graphs)
class NodesToGlobalsAggregator(_base.AbstractModule):
  """Aggregates all nodes into globals.

  The nodes of each graph are reduced together, so the leading dimension of
  the output is the number of graphs.
  """

  def __init__(self, reducer, name="nodes_to_globals_aggregator"):
    """Initializes the NodesToGlobalsAggregator module.

    The reducer is used for combining per-node features (one set of node
    feature vectors per graph) to give per-graph features (one feature
    vector per graph). The reducer should take a `Tensor` of node features, a
    `Tensor` of segment indices, and a number of graphs. It should be invariant
    under permutation of node features within each graph.

    Examples of compatible reducers are:
    * tf.math.unsorted_segment_sum
    * tf.math.unsorted_segment_mean
    * tf.math.unsorted_segment_prod
    * unsorted_segment_min_or_zero
    * unsorted_segment_max_or_zero

    Args:
      reducer: A function for reducing sets of per-node features to individual
        per-graph features.
      name: The module name.
    """
    super(NodesToGlobalsAggregator, self).__init__(name=name)
    self._reducer = reducer

  def _build(self, graph):
    """Reduces the node features of each graph to one feature per graph.

    Args:
      graph: A `graphs.GraphsTuple` whose `nodes` field is not `None`; its
        `n_node` field gives the number of nodes of each graph.

    Returns:
      The value returned by the reducer when called with the node features,
      the graph index of every node, and the number of graphs.

    Raises:
      ValueError: If `graph.nodes` is `None`.
    """
    # No trailing period here: `_validate_graph` terminates the sentence.
    _validate_graph(graph, (NODES,),
                    additional_message="when aggregating from nodes")
    num_graphs = utils_tf.get_num_graphs(graph)
    graph_index = tf.range(num_graphs)
    # Repeat each graph index once per node of that graph, giving the segment
    # id of every node.
    indices = utils_tf.repeat(graph_index, graph.n_node, axis=0)
    return self._reducer(graph.nodes, indices, num_graphs)
class _EdgesToNodesAggregator(_base.AbstractModule):
  """Aggregates sent or received edges into the corresponding nodes."""

  def __init__(self, reducer, use_sent_edges=False,
               name="edges_to_nodes_aggregator"):
    """Initializes the aggregator.

    Args:
      reducer: A callable taking the edge features, the per-edge node indices
        and the number of nodes, and returning per-node features.
      use_sent_edges: (bool, default=False) If `True`, edges are grouped by
        their sender node; otherwise by their receiver node.
      name: The module name.
    """
    super(_EdgesToNodesAggregator, self).__init__(name=name)
    self._reducer = reducer
    self._use_sent_edges = use_sent_edges

  def _build(self, graph):
    """Reduces the edge features onto the nodes they are attached to.

    Args:
      graph: A `graphs.GraphsTuple` with non-`None` `edges`, `senders` and
        `receivers`.

    Returns:
      The value returned by the reducer when called with the edge features,
      the sender (or receiver) index of every edge, and the total number of
      nodes.

    Raises:
      ValueError: If `graph.edges`, `graph.senders` or `graph.receivers` is
        `None`.
    """
    # No trailing period here: `_validate_graph` terminates the sentence.
    _validate_graph(graph, (EDGES, SENDERS, RECEIVERS,),
                    additional_message="when aggregating from edges")
    # If the number of nodes is known at graph construction time (based on the
    # shape) then use that value to make the model compatible with XLA/TPU.
    static_num_nodes = None
    if graph.nodes is not None:
      static_num_nodes = graph.nodes.shape.as_list()[0]
    if static_num_nodes is not None:
      num_nodes = static_num_nodes
    else:
      num_nodes = tf.reduce_sum(graph.n_node)
    indices = graph.senders if self._use_sent_edges else graph.receivers
    return self._reducer(graph.edges, indices, num_nodes)
class SentEdgesToNodesAggregator(_EdgesToNodesAggregator):
  """Aggregates sent edges into the corresponding sender nodes."""

  def __init__(self, reducer, name="sent_edges_to_nodes_aggregator"):
    """Constructor.

    Edges are grouped by their sender node. The reducer combines the per-edge
    features of each group (one set of edge feature vectors per node) into
    per-node features (one feature vector per node). It should take a `Tensor`
    of edge features, a `Tensor` of segment indices, and a number of nodes, and
    it should be invariant under permutation of edge features within each
    segment.

    Examples of compatible reducers are:
    * tf.math.unsorted_segment_sum
    * tf.math.unsorted_segment_mean
    * tf.math.unsorted_segment_prod
    * unsorted_segment_min_or_zero
    * unsorted_segment_max_or_zero

    Args:
      reducer: A function for reducing sets of per-edge features to individual
        per-node features.
      name: The module name.
    """
    super(SentEdgesToNodesAggregator, self).__init__(
        reducer=reducer, use_sent_edges=True, name=name)
class ReceivedEdgesToNodesAggregator(_EdgesToNodesAggregator):
  """Aggregates received edges into the corresponding receiver nodes."""

  def __init__(self, reducer, name="received_edges_to_nodes_aggregator"):
    """Constructor.

    Edges are grouped by their receiver node. The reducer combines the
    per-edge features of each group (one set of edge feature vectors per node)
    into per-node features (one feature vector per node). It should take a
    `Tensor` of edge features, a `Tensor` of segment indices, and a number of
    nodes, and it should be invariant under permutation of edge features within
    each segment.

    Examples of compatible reducers are:
    * tf.math.unsorted_segment_sum
    * tf.math.unsorted_segment_mean
    * tf.math.unsorted_segment_prod
    * unsorted_segment_min_or_zero
    * unsorted_segment_max_or_zero

    Args:
      reducer: A function for reducing sets of per-edge features to individual
        per-node features.
      name: The module name.
    """
    super(ReceivedEdgesToNodesAggregator, self).__init__(
        reducer=reducer,
        use_sent_edges=False,
        name=name)
def _unsorted_segment_reduction_or_zero(reducer, values, indices, num_groups):
  """Common code for unsorted_segment_{min,max}_or_zero (below).

  Applies `reducer`, then multiplies the result by a 0/1 mask so that segments
  which no index refers to end up as zero.
  """
  reduced = reducer(values, indices, num_groups)
  # Segment-max over a vector of ones: 1 for every segment that owns at least
  # one element.
  ones = tf.ones_like(indices, dtype=reduced.dtype)
  segment_mask = tf.math.unsorted_segment_max(ones, indices, num_groups)
  # Whatever the max-reduction reports for empty segments is clipped to 0.
  segment_mask = tf.clip_by_value(segment_mask, 0, 1)
  # Shape the mask as [num_groups, 1, ..., 1] so that it broadcasts against
  # the trailing axes of `reduced`.
  trailing_ones = [1] * (reduced.shape.ndims - 1)
  segment_mask = tf.reshape(segment_mask, [num_groups] + trailing_ones)
  return reduced * segment_mask
def unsorted_segment_min_or_zero(values, indices, num_groups,
                                 name="unsorted_segment_min_or_zero"):
  """Aggregates information using elementwise min.

  Segments with no elements are given a "min" of zero instead of the most
  positive finite value possible (which is what `tf.math.unsorted_segment_min`
  would do).

  Args:
    values: A `Tensor` of per-element features.
    indices: A 1-D `Tensor` whose length is equal to `values`' first dimension.
    num_groups: A `Tensor`.
    name: (string, optional) A name for the operation.

  Returns:
    A `Tensor` of the same type as `values`.
  """
  segment_min = tf.math.unsorted_segment_min
  with tf.name_scope(name):
    return _unsorted_segment_reduction_or_zero(
        segment_min, values, indices, num_groups)
def unsorted_segment_max_or_zero(values, indices, num_groups,
                                 name="unsorted_segment_max_or_zero"):
  """Aggregates information using elementwise max.

  Segments with no elements are given a "max" of zero instead of the most
  negative finite value possible (which is what `tf.math.unsorted_segment_max`
  would do).

  Args:
    values: A `Tensor` of per-element features.
    indices: A 1-D `Tensor` whose length is equal to `values`' first dimension.
    num_groups: A `Tensor`.
    name: (string, optional) A name for the operation.

  Returns:
    A `Tensor` of the same type as `values`.
  """
  segment_max = tf.math.unsorted_segment_max
  with tf.name_scope(name):
    return _unsorted_segment_reduction_or_zero(
        segment_max, values, indices, num_groups)
class EdgeBlock(_base.AbstractModule):
  """Edge block.

  A block that updates the features of each edge in a batch of graphs based on
  (a subset of) the previous edge features, the features of the adjacent nodes,
  and the global features of the corresponding graph.

  See https://arxiv.org/abs/1806.01261 for more details.
  """

  def __init__(self,
               edge_model_fn,
               use_edges=True,
               use_receiver_nodes=True,
               use_sender_nodes=True,
               use_globals=True,
               name="edge_block"):
    """Initializes the EdgeBlock module.

    Args:
      edge_model_fn: A callable that will be called in the variable scope of
        this EdgeBlock and should return a Sonnet module (or equivalent
        callable) to be used as the edge model. The returned module should take
        a `Tensor` (of concatenated input features for each edge) and return a
        `Tensor` (of output features for each edge). Typically, this module
        would input and output `Tensor`s of rank 2, but it may also be input or
        output larger ranks. See the `_build` method documentation for more
        details on the acceptable inputs to this module in that case.
      use_edges: (bool, default=True). Whether to condition on edge attributes.
      use_receiver_nodes: (bool, default=True). Whether to condition on receiver
        node attributes.
      use_sender_nodes: (bool, default=True). Whether to condition on sender
        node attributes.
      use_globals: (bool, default=True). Whether to condition on global
        attributes.
      name: The module name.

    Raises:
      ValueError: If `use_edges`, `use_receiver_nodes`, `use_sender_nodes` and
        `use_globals` are all False.
    """
    super(EdgeBlock, self).__init__(name=name)

    if not (use_edges or use_sender_nodes or use_receiver_nodes or use_globals):
      raise ValueError("At least one of use_edges, use_sender_nodes, "
                       "use_receiver_nodes or use_globals must be True.")

    # Remember which inputs the edge model is conditioned on.
    self._use_edges = use_edges
    self._use_receiver_nodes = use_receiver_nodes
    self._use_sender_nodes = use_sender_nodes
    self._use_globals = use_globals

    with self._enter_variable_scope():
      self._edge_model = edge_model_fn()

  def _build(self, graph):
    """Connects the edge block.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s, whose individual edges
        features (if `use_edges` is `True`), individual nodes features (if
        `use_receiver_nodes` or `use_sender_nodes` is `True`) and per graph
        globals (if `use_globals` is `True`) should be concatenable on the last
        axis.

    Returns:
      An output `graphs.GraphsTuple` with updated edges.

    Raises:
      ValueError: If `graph` does not have non-`None` receivers and senders, or
        if `graph` has `None` fields incompatible with the selected `use_edges`,
        `use_receiver_nodes`, `use_sender_nodes`, or `use_globals` options.
    """
    # No leading space in the context string: `_validate_graph` inserts the
    # separator itself, so a leading space would be doubled in the message.
    _validate_graph(
        graph, (SENDERS, RECEIVERS, N_EDGE), "when using an EdgeBlock")

    edges_to_collect = []

    if self._use_edges:
      _validate_graph(graph, (EDGES,), "when use_edges == True")
      edges_to_collect.append(graph.edges)

    if self._use_receiver_nodes:
      # Feature of the receiver node of every edge.
      edges_to_collect.append(broadcast_receiver_nodes_to_edges(graph))

    if self._use_sender_nodes:
      # Feature of the sender node of every edge.
      edges_to_collect.append(broadcast_sender_nodes_to_edges(graph))

    if self._use_globals:
      # Globals of the graph every edge belongs to.
      edges_to_collect.append(broadcast_globals_to_edges(graph))

    # Every collected tensor has one row per edge, so the inputs are joined
    # on the last (feature) axis.
    collected_edges = tf.concat(edges_to_collect, axis=-1)
    updated_edges = self._edge_model(collected_edges)
    return graph.replace(edges=updated_edges)
class NodeBlock(_base.AbstractModule):
  """Node block.

  A block that updates the features of each node in batch of graphs based on
  (a subset of) the previous node features, the aggregated features of the
  adjacent edges, and the global features of the corresponding graph.

  See https://arxiv.org/abs/1806.01261 for more details.
  """

  def __init__(self,
               node_model_fn,
               use_received_edges=True,
               use_sent_edges=False,
               use_nodes=True,
               use_globals=True,
               received_edges_reducer=tf.math.unsorted_segment_sum,
               sent_edges_reducer=tf.math.unsorted_segment_sum,
               name="node_block"):
    """Initializes the NodeBlock module.

    Args:
      node_model_fn: A callable that will be called in the variable scope of
        this NodeBlock and should return a Sonnet module (or equivalent
        callable) to be used as the node model. The returned module should take
        a `Tensor` (of concatenated input features for each node) and return a
        `Tensor` (of output features for each node). Typically, this module
        would input and output `Tensor`s of rank 2, but it may also be input or
        output larger ranks. See the `_build` method documentation for more
        details on the acceptable inputs to this module in that case.
      use_received_edges: (bool, default=True) Whether to condition on
        aggregated edges received by each node.
      use_sent_edges: (bool, default=False) Whether to condition on aggregated
        edges sent by each node.
      use_nodes: (bool, default=True) Whether to condition on node attributes.
      use_globals: (bool, default=True) Whether to condition on global
        attributes.
      received_edges_reducer: Reduction to be used when aggregating received
        edges. This should be a callable whose signature matches
        `tf.math.unsorted_segment_sum`.
      sent_edges_reducer: Reduction to be used when aggregating sent edges.
        This should be a callable whose signature matches
        `tf.math.unsorted_segment_sum`.
      name: The module name.

    Raises:
      ValueError: If all of the `use_*` flags are False, or if the reducer
        belonging to an enabled `use_received_edges` / `use_sent_edges` flag
        is `None`.
    """
    super(NodeBlock, self).__init__(name=name)

    enabled_inputs = (
        use_nodes, use_sent_edges, use_received_edges, use_globals)
    if not any(enabled_inputs):
      raise ValueError("At least one of use_received_edges, use_sent_edges, "
                       "use_nodes or use_globals must be True.")

    # Remember which inputs the node model is conditioned on.
    self._use_received_edges = use_received_edges
    self._use_sent_edges = use_sent_edges
    self._use_nodes = use_nodes
    self._use_globals = use_globals

    with self._enter_variable_scope():
      self._node_model = node_model_fn()

      # Aggregator grouping edges by their receiver node.
      if self._use_received_edges:
        if received_edges_reducer is None:
          raise ValueError(
              "If `use_received_edges==True`, `received_edges_reducer` "
              "should not be None.")
        received_aggregator = ReceivedEdgesToNodesAggregator(
            received_edges_reducer)
        self._received_edges_aggregator = received_aggregator

      # Aggregator grouping edges by their sender node.
      if self._use_sent_edges:
        if sent_edges_reducer is None:
          raise ValueError(
              "If `use_sent_edges==True`, `sent_edges_reducer` "
              "should not be None.")
        sent_aggregator = SentEdgesToNodesAggregator(sent_edges_reducer)
        self._sent_edges_aggregator = sent_aggregator

  def _build(self, graph):
    """Connects the node block.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s, whose individual edges
        features (if `use_received_edges` or `use_sent_edges` is `True`),
        individual nodes features (if `use_nodes` is True) and per graph globals
        (if `use_globals` is `True`) should be concatenable on the last axis.

    Returns:
      An output `graphs.GraphsTuple` with updated nodes.

    Raises:
      ValueError: If `use_nodes` is True and `graph.nodes` is `None`, or if
        one of the enabled aggregation / broadcast steps finds a field it
        requires to be `None`.
    """
    model_inputs = []

    if self._use_received_edges:
      received = self._received_edges_aggregator(graph)
      model_inputs.append(received)

    if self._use_sent_edges:
      sent = self._sent_edges_aggregator(graph)
      model_inputs.append(sent)

    if self._use_nodes:
      _validate_graph(graph, (NODES,), "when use_nodes == True")
      model_inputs.append(graph.nodes)

    if self._use_globals:
      model_inputs.append(broadcast_globals_to_nodes(graph))

    # Every collected tensor has one row per node, so the inputs are joined
    # on the last (feature) axis.
    node_model_input = tf.concat(model_inputs, axis=-1)
    return graph.replace(nodes=self._node_model(node_model_input))
class GlobalBlock(_base.AbstractModule):
  """Global block.

  A block that updates the global features of each graph in a batch based on
  (a subset of) the previous global features, the aggregated features of the
  edges of the graph, and the aggregated features of the nodes of the graph.

  See https://arxiv.org/abs/1806.01261 for more details.
  """

  def __init__(self,
               global_model_fn,
               use_edges=True,
               use_nodes=True,
               use_globals=True,
               nodes_reducer=tf.math.unsorted_segment_sum,
               edges_reducer=tf.math.unsorted_segment_sum,
               name="global_block"):
    """Initializes the GlobalBlock module.

    Args:
      global_model_fn: A callable that will be called in the variable scope of
        this GlobalBlock and should return a Sonnet module (or equivalent
        callable) to be used as the global model. The returned module should
        take a `Tensor` (of concatenated input features) and return a `Tensor`
        (the global output features). Typically, this module would input and
        output `Tensor`s of rank 2, but it may also input or output larger
        ranks. See the `_build` method documentation for more details on the
        acceptable inputs to this module in that case.
      use_edges: (bool, default=True) Whether to condition on aggregated edges.
      use_nodes: (bool, default=True) Whether to condition on node attributes.
      use_globals: (bool, default=True) Whether to condition on global
        attributes.
      nodes_reducer: Reduction to be used when aggregating nodes. This should
        be a callable whose signature matches tf.math.unsorted_segment_sum.
      edges_reducer: Reduction to be used when aggregating edges. This should
        be a callable whose signature matches tf.math.unsorted_segment_sum.
      name: The module name.

    Raises:
      ValueError: If all of the `use_*` flags are False, or if the reducer
        belonging to an enabled `use_edges` / `use_nodes` flag is `None`.
    """
    super(GlobalBlock, self).__init__(name=name)

    enabled_inputs = (use_nodes, use_edges, use_globals)
    if not any(enabled_inputs):
      raise ValueError("At least one of use_edges, "
                       "use_nodes or use_globals must be True.")

    # Remember which inputs the global model is conditioned on.
    self._use_edges = use_edges
    self._use_nodes = use_nodes
    self._use_globals = use_globals

    with self._enter_variable_scope():
      self._global_model = global_model_fn()

      # Aggregator reducing the edges of each graph.
      if self._use_edges:
        if edges_reducer is None:
          raise ValueError(
              "If `use_edges==True`, `edges_reducer` should not be None.")
        edges_aggregator = EdgesToGlobalsAggregator(edges_reducer)
        self._edges_aggregator = edges_aggregator

      # Aggregator reducing the nodes of each graph.
      if self._use_nodes:
        if nodes_reducer is None:
          raise ValueError(
              "If `use_nodes==True`, `nodes_reducer` should not be None.")
        nodes_aggregator = NodesToGlobalsAggregator(nodes_reducer)
        self._nodes_aggregator = nodes_aggregator

  def _build(self, graph):
    """Connects the global block.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s, whose individual edges
        (if `use_edges` is `True`), individual nodes (if `use_nodes` is True)
        and per graph globals (if `use_globals` is `True`) should be
        concatenable on the last axis.

    Returns:
      An output `graphs.GraphsTuple` with updated globals.

    Raises:
      ValueError: If a field required by one of the enabled `use_edges`,
        `use_nodes` or `use_globals` options is `None`.
    """
    model_inputs = []

    if self._use_edges:
      _validate_graph(graph, (EDGES,), "when use_edges == True")
      aggregated_edges = self._edges_aggregator(graph)
      model_inputs.append(aggregated_edges)

    if self._use_nodes:
      _validate_graph(graph, (NODES,), "when use_nodes == True")
      aggregated_nodes = self._nodes_aggregator(graph)
      model_inputs.append(aggregated_nodes)

    if self._use_globals:
      _validate_graph(graph, (GLOBALS,), "when use_globals == True")
      model_inputs.append(graph.globals)

    # Every collected tensor has one row per graph, so the inputs are joined
    # on the last (feature) axis.
    global_model_input = tf.concat(model_inputs, axis=-1)
    return graph.replace(globals=self._global_model(global_model_input))
(2)graphs.py
# Copyright 2018 The GraphNets Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""A class that defines graph-structured data.
The main purpose of the `GraphsTuple` is to represent multiple graphs with
different shapes and sizes in a way that supports batched processing.
This module first defines the string constants which are used to represent
graph(s) as tuples or dictionaries: `N_NODE, N_EDGE, NODES, EDGES, RECEIVERS,
SENDERS, GLOBALS`.
This representation could typically take the following form, for a batch of
`n_graphs` graphs stored in a `GraphsTuple` called graph:
- N_NODE: The number of nodes per graph. It is a vector of integers with shape
`[n_graphs]`, such that `graph.N_NODE[i]` is the number of nodes in the i-th
graph.
- N_EDGE: The number of edges per graph. It is a vector of integers with shape
`[n_graphs]`, such that `graph.N_EDGE[i]` is the number of edges in the i-th
graph.
- NODES: The nodes features. It is either `None` (the graph has no node
features), or a vector of shape `[n_nodes] + node_shape`, where
`n_nodes = sum(graph.N_NODE)` is the total number of nodes in the batch of
graphs, and `node_shape` represents the shape of the features of each node.
The relative index of a node from the batched version can be recovered from
the `graph.N_NODE` property. For instance, the second node of the third
graph will have its features in the
`1 + graph.N_NODE[0] + graph.N_NODE[1]`-th slot of graph.NODES.
Observe that having a `None` value for this field does not mean that the
graphs have no nodes, only that they do not have features.
- EDGES: The edges features. It is either `None` (the graph has no edge
features), or a vector of shape `[n_edges] + edge_shape`, where
`n_edges = sum(graph.N_EDGE)` is the total number of edges in the batch of
graphs, and `edge_shape` represents the shape of the features of each edge.
The relative index of an edge from the batched version can be recovered from
the `graph.N_EDGE` property. For instance, the third edge of the third
graph will have its features in the `2 + graph.N_EDGE[0] + graph.N_EDGE[1]`-
th slot of graph.EDGES.
Observe that having a `None` value for this field does not necessarily mean
that the graph has no edges, only that they do not have features.
- RECEIVERS: The indices of the receiver nodes, for each edge. It is either
`None` (if the graph has no edges), or a vector of integers of shape
`[n_edges]`, such that `graph.RECEIVERS[i]` is the index of the node
receiving from the i-th edge.
Observe that the index is absolute (in other words, cumulative), i.e.
`graphs.RECEIVERS` take value in `[0, n_nodes)`. For instance, an edge
connecting the vertices with relative indices 2 and 3 in the second graph of
the batch would have a `RECEIVERS` value of `3 + graph.N_NODE[0]`.
If `graphs.RECEIVERS` is `None`, then `graphs.EDGES` and `graphs.SENDERS`
should also be `None`.
- SENDERS: The indices of the sender nodes, for each edge. It is either
`None` (if the graph has no edges), or a vector of integers of shape
`[n_edges]`, such that `graph.SENDERS[i]` is the index of the node
sending along the i-th edge.
Observe that the index is absolute, i.e. `graphs.SENDERS` take value in
`[0, n_nodes)`. For instance, an edge connecting the vertices with relative
indices 1 and 3 in the third graph of the batch would have a `SENDERS` value
of `1 + graph.N_NODE[0] + graph.N_NODE[1]`.
If `graphs.SENDERS` is `None`, then `graphs.EDGES` and `graphs.RECEIVERS`
should also be `None`.
- GLOBALS: The global features of the graph. It is either `None` (the graph
has no global features), or a vector of shape `[n_graphs] + global_shape`
representing graph level features.
The `utils_np` and `utils_tf` modules provide convenience methods to work with
graph that contain numpy and tensorflow data, respectively: conversion,
batching, unbatching, indexing, among others.
The `GraphsTuple` class, however, is not restricted to storing vectors, and can
be used to store attributes of graphs as well (for instance, types or shapes).
The only assertions it makes are that the `None` fields are compatible with the
definition of a graph given above, namely:
- the N_NODE and N_EDGE fields cannot be `None`;
- if RECEIVERS is None, then SENDERS must be `None` (and vice-versa);
- if RECEIVERS and SENDERS are `None`, then `EDGES` must be `None`.
Those assumptions are checked both upon initialization and when replacing a
field by calling the `replace` or `map` method.
"""
"""
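Summary of the module documentation above.
The main purpose of `GraphsTuple` is to represent multiple graphs of different
shapes and sizes in a way that supports batched processing.
This module first defines the string constants used to represent graphs as
tuples or dictionaries:
N_NODE, N_EDGE, NODES, EDGES, RECEIVERS, SENDERS, GLOBALS
N_NODE: vector of shape [n_graphs] holding the number of nodes of each graph, so graph.N_NODE[i] is the number of nodes in the i-th graph.
N_EDGE: vector of shape [n_graphs] holding the number of edges of each graph, so graph.N_EDGE[i] is the number of edges in the i-th graph.
NODES: multi-dimensional array; the first dimension is the total number of nodes (it can be split per graph using N_NODE) and the remaining dimensions are the node features.
For example, the second node of the third graph has its features at index 1 + graph.N_NODE[0] + graph.N_NODE[1] of graph.NODES.
EDGES: multi-dimensional array; the first dimension is the total number of edges (it can be split per graph using N_EDGE) and the remaining dimensions are the edge features.
For example, the third edge of the third graph has its features at index 2 + graph.N_EDGE[0] + graph.N_EDGE[1] of graph.EDGES.
RECEIVERS: 1-D vector with one entry per edge (length sum(N_EDGE)); entry i is the absolute index of the node receiving the i-th edge, e.g. for an edge 2->3 the value is the index of node 3.
SENDERS: 1-D vector with one entry per edge (length sum(N_EDGE)); entry i is the absolute index of the node sending the i-th edge, e.g. for an edge 2->3 the value is the index of node 2.
GLOBALS: global features; the first dimension is the number of graphs (n_graphs), so it is indexed per graph, and the remaining dimensions are the features.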
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
# Names of the individual `GraphsTuple` fields.
NODES = "nodes"
EDGES = "edges"
RECEIVERS = "receivers"
SENDERS = "senders"
GLOBALS = "globals"
N_NODE = "n_node"
N_EDGE = "n_edge"

# Groups of field names, by role.
GRAPH_FEATURE_FIELDS = (NODES, EDGES, GLOBALS)
GRAPH_INDEX_FIELDS = (RECEIVERS, SENDERS)
GRAPH_NUMBER_FIELDS = (N_NODE, N_EDGE)
GRAPH_DATA_FIELDS = (NODES, EDGES, RECEIVERS, SENDERS, GLOBALS)

# Every field: the data fields followed by the per-graph counts.
ALL_FIELDS = GRAPH_DATA_FIELDS + GRAPH_NUMBER_FIELDS
class GraphsTuple(
    collections.namedtuple("GraphsTuple",
                           GRAPH_DATA_FIELDS + GRAPH_NUMBER_FIELDS)):
  """Default namedtuple describing `Graphs`s.

  A child of `collections.namedtuple`, which allows it to be directly input
  and output from `tensorflow.Session.run()` calls.

  An instance of this class can be constructed as
  ```
  GraphsTuple(nodes=nodes,
              edges=edges,
              globals=globals,
              receivers=receivers,
              senders=senders,
              n_node=n_node,
              n_edge=n_edge)
  ```
  where `nodes`, `edges`, `globals`, `receivers`, `senders`, `n_node` and
  `n_edge` are arbitrary, but are typically numpy arrays, tensors, or `None`;
  see module's documentation for a more detailed description of which fields
  can be left `None`.
  """

  def _validate_none_fields(self):
    """Asserts that the set of `None` fields in the instance is valid.

    Raises:
      ValueError: If `n_node` or `n_edge` is `None`, if exactly one of
        `receivers` and `senders` is `None`, or if `edges` is set while
        `receivers` is `None`.
    """
    # The per-graph counts are always mandatory.
    for number_field in (N_NODE, N_EDGE):
      if getattr(self, number_field) is None:
        raise ValueError("Field `{}` cannot be None".format(number_field))

    # `receivers` and `senders` must be set (or unset) together.
    no_receivers = self.receivers is None
    no_senders = self.senders is None
    if no_receivers and not no_senders:
      raise ValueError(
          "Field `senders` must be None as field `receivers` is None")
    if no_senders and not no_receivers:
      raise ValueError(
          "Field `receivers` must be None as field `senders` is None")

    # Edge features cannot be set without the edge endpoints.
    if no_receivers and self.edges is not None:
      raise ValueError(
          "Field `edges` must be None as field `receivers` and `senders` are "
          "None")

  def __init__(self, *args, **kwargs):
    """Validates the `None` fields of the freshly created instance."""
    # The fields of a `namedtuple` are filled in the `__new__` method.
    # `__init__` does not accept parameters.
    del args, kwargs
    super(GraphsTuple, self).__init__()
    self._validate_none_fields()

  def replace(self, **kwargs):
    """Returns a validated copy of the instance with some fields replaced.

    Args:
      **kwargs: The new values, keyed by field name.

    Returns:
      A new `GraphsTuple` built with `_replace`.

    Raises:
      ValueError: If the `None` fields of the new instance are invalid.
    """
    updated = self._replace(**kwargs)
    updated._validate_none_fields()  # pylint: disable=protected-access
    return updated

  def map(self, field_fn, fields=GRAPH_FEATURE_FIELDS):
    """Applies `field_fn` to the fields `fields` of the instance.

    `field_fn` is applied exactly once per field in `fields`. The result must
    satisfy the `GraphsTuple` requirement w.r.t. `None` fields, i.e. the
    `SENDERS` cannot be `None` if the `EDGES` or `RECEIVERS` are not `None`,
    etc.

    Args:
      field_fn: A callable that take a single argument.
      fields: (iterable of `str`). An iterable of the fields to apply
        `field_fn` to.

    Returns:
      A copy of the instance, with the fields in `fields` replaced by the result
      of applying `field_fn` to them.
    """
    replacements = {}
    for field_name in fields:
      current_value = getattr(self, field_name)
      replacements[field_name] = field_fn(current_value)
    return self.replace(**replacements)
(3)modules.py
# Copyright 2018 The GraphNets Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Common Graph Network architectures.
The modules in this file are Sonnet modules that:
- take a `graphs.GraphsTuple` containing `Tensor`s as input, with possibly
`None` fields (depending on the module);
- return a `graphs.GraphsTuple` with updated values for some fields
(depending on the module).
The provided modules are:
- `GraphNetwork`: a general purpose Graph Network composed of configurable
`EdgeBlock`, `NodeBlock` and `GlobalBlock` from `blocks.py`;
- `GraphIndependent`: a Graph Network producing updated edges (resp. nodes,
globals) based on the input's edges (resp. nodes, globals) only;
- `InteractionNetwork` (from https://arxiv.org/abs/1612.00222): a
network propagating information on the edges and nodes of a graph;
- RelationNetwork (from https://arxiv.org/abs/1706.01427): a network
updating the global property based on the relation between the input's
nodes properties;
- DeepSets (from https://arxiv.org/abs/1703.06114): a network that operates on
sets (graphs without edges);
- CommNet (from https://arxiv.org/abs/1605.07736 and
https://arxiv.org/abs/1706.06122): a network updating nodes based on their
previous features and the features of the adjacent nodes.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from graph_nets import _base
from graph_nets import blocks
import tensorflow as tf
# Default `EdgeBlock` options of a generic `GraphNetwork`: which inputs the
# edge model receives.
_DEFAULT_EDGE_BLOCK_OPT = dict(
    use_edges=True,
    use_receiver_nodes=True,
    use_sender_nodes=True,
    use_globals=True,
)

# Default `NodeBlock` options of a generic `GraphNetwork`: which inputs the
# node model receives. Sent edges are the only input disabled by default.
_DEFAULT_NODE_BLOCK_OPT = dict(
    use_received_edges=True,
    use_sent_edges=False,
    use_nodes=True,
    use_globals=True,
)

# Default `GlobalBlock` options of a generic `GraphNetwork`: which inputs the
# global model receives.
_DEFAULT_GLOBAL_BLOCK_OPT = dict(
    use_edges=True,
    use_nodes=True,
    use_globals=True,
)
class InteractionNetwork(_base.AbstractModule):
    """Implementation of an Interaction Network.

    An interaction network computes interactions on the edges based on the
    previous edge features and on the features of the nodes connected to
    those edges. It then updates the nodes based on the incoming updated
    edges.

    See https://arxiv.org/abs/1612.00222 for more details.

    This model does not update the graph globals, and they are allowed to be
    `None`.
    """

    def __init__(self,
                 edge_model_fn,
                 node_model_fn,
                 reducer=tf.math.unsorted_segment_sum,
                 name="interaction_network"):
        """Initializes the InteractionNetwork module.

        Args:
          edge_model_fn: A callable that will be passed to `EdgeBlock` to
            perform per-edge computations. The callable must return a Sonnet
            module (or equivalent; see `blocks.EdgeBlock` for details), and
            the shape of the output of this module must match the one of the
            input nodes, but for the first and last axis.
          node_model_fn: A callable that will be passed to `NodeBlock` to
            perform per-node computations. The callable must return a Sonnet
            module (or equivalent; see `blocks.NodeBlock` for details).
          reducer: Reducer used by the `NodeBlock` to aggregate received
            edges. Defaults to `tf.math.unsorted_segment_sum`.
          name: The module name.
        """
        super(InteractionNetwork, self).__init__(name=name)

        # Neither block looks at the globals; the node block additionally
        # ignores the edges sent by each node.
        edge_block_kwargs = dict(
            edge_model_fn=edge_model_fn,
            use_globals=False)
        node_block_kwargs = dict(
            node_model_fn=node_model_fn,
            use_sent_edges=False,
            use_globals=False,
            received_edges_reducer=reducer)

        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(**edge_block_kwargs)
            self._node_block = blocks.NodeBlock(**node_block_kwargs)

    def _build(self, graph):
        """Connects the InteractionNetwork.

        Args:
          graph: A `graphs.GraphsTuple` containing `Tensor`s.
            `graph.globals` can be `None`. The features of each node and edge
            of `graph` must be concatenable on the last axis (i.e., the
            shapes of `graph.nodes` and `graph.edges` must match but for
            their first and last axis).

        Returns:
          An output `graphs.GraphsTuple` with updated edges and nodes.

        Raises:
          ValueError: If any of `graph.nodes`, `graph.edges`,
            `graph.receivers` or `graph.senders` is `None`.
        """
        graph_with_new_edges = self._edge_block(graph)
        return self._node_block(graph_with_new_edges)
class RelationNetwork(_base.AbstractModule):
    """Implementation of a Relation Network.

    See https://arxiv.org/abs/1706.01427 for more details.

    The global and edge features of the input graph are not used, and are
    allowed to be `None` (the receivers and senders properties must be
    present). The output graph has updated, non-`None`, globals.
    """

    def __init__(self,
                 edge_model_fn,
                 global_model_fn,
                 reducer=tf.math.unsorted_segment_sum,
                 name="relation_network"):
        """Initializes the RelationNetwork module.

        Args:
          edge_model_fn: A callable that will be passed to EdgeBlock to
            perform per-edge computations. The callable must return a Sonnet
            module (or equivalent; see EdgeBlock for details).
          global_model_fn: A callable that will be passed to GlobalBlock to
            perform per-global computations. The callable must return a
            Sonnet module (or equivalent; see GlobalBlock for details).
          reducer: Reducer used by the GlobalBlock to aggregate edges.
            Defaults to `tf.math.unsorted_segment_sum`.
          name: The module name.
        """
        super(RelationNetwork, self).__init__(name=name)

        # Edges are computed from the receiver and sender nodes only.
        edge_block_kwargs = dict(
            edge_model_fn=edge_model_fn,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False)
        # Globals are computed from the aggregated edges only.
        global_block_kwargs = dict(
            global_model_fn=global_model_fn,
            use_edges=True,
            use_nodes=False,
            use_globals=False,
            edges_reducer=reducer)

        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(**edge_block_kwargs)
            self._global_block = blocks.GlobalBlock(**global_block_kwargs)

    def _build(self, graph):
        """Connects the RelationNetwork.

        Args:
          graph: A `graphs.GraphsTuple` containing `Tensor`s, except for the
            edges and global properties which may be `None`.

        Returns:
          A `graphs.GraphsTuple` with updated globals. All the other fields
          are taken from the input `graph`.

        Raises:
          ValueError: If any of `graph.nodes`, `graph.receivers` or
            `graph.senders` is `None`.
        """
        graph_with_new_edges = self._edge_block(graph)
        graph_with_new_globals = self._global_block(graph_with_new_edges)
        # The intermediate edges are dropped: only the globals are copied
        # back onto the input graph.
        return graph.replace(globals=graph_with_new_globals.globals)
def _make_default_edge_block_opt(edge_block_opt):
    """Default options to be used in the EdgeBlock of a generic GraphNetwork.

    Args:
      edge_block_opt: Mapping of user-provided `EdgeBlock` options; may be
        `None` or empty. It is copied, never modified.

    Returns:
      A new dict with the user-provided options, completed with the entries
      of `_DEFAULT_EDGE_BLOCK_OPT` for every key the user did not set.
    """
    if edge_block_opt:
        options = dict(edge_block_opt.items())
    else:
        options = {}
    # User-provided values take precedence over the defaults.
    for option_name, default_value in _DEFAULT_EDGE_BLOCK_OPT.items():
        options.setdefault(option_name, default_value)
    return options
def _make_default_node_block_opt(node_block_opt, default_reducer):
    """Default options to be used in the NodeBlock of a generic GraphNetwork.

    Args:
      node_block_opt: Mapping of user-provided `NodeBlock` options; may be
        `None` or empty. It is copied, never modified.
      default_reducer: Reducer stored under `received_edges_reducer` and
        `sent_edges_reducer` when the user did not provide those keys.

    Returns:
      A new dict with the user-provided options, completed with the entries
      of `_DEFAULT_NODE_BLOCK_OPT` and with the two reducer entries.
    """
    if node_block_opt:
        options = dict(node_block_opt.items())
    else:
        options = {}
    # User-provided values take precedence over the defaults.
    for option_name, default_value in _DEFAULT_NODE_BLOCK_OPT.items():
        options.setdefault(option_name, default_value)
    # The reducers have no static default: they fall back to the reducer
    # supplied by the caller.
    for reducer_name in ["received_edges_reducer", "sent_edges_reducer"]:
        options.setdefault(reducer_name, default_reducer)
    return options
def _make_default_global_block_opt(global_block_opt, default_reducer):
    """Default options to be used in the GlobalBlock of a generic GraphNetwork.

    Args:
      global_block_opt: Mapping of user-provided `GlobalBlock` options; may
        be `None` or empty. It is copied, never modified.
      default_reducer: Reducer stored under `edges_reducer` and
        `nodes_reducer` when the user did not provide those keys.

    Returns:
      A new dict with the user-provided options, completed with the entries
      of `_DEFAULT_GLOBAL_BLOCK_OPT` and with the two reducer entries.
    """
    if global_block_opt:
        options = dict(global_block_opt.items())
    else:
        options = {}
    # User-provided values take precedence over the defaults.
    for option_name, default_value in _DEFAULT_GLOBAL_BLOCK_OPT.items():
        options.setdefault(option_name, default_value)
    # The reducers have no static default: they fall back to the reducer
    # supplied by the caller.
    for reducer_name in ["edges_reducer", "nodes_reducer"]:
        options.setdefault(reducer_name, default_reducer)
    return options
class GraphNetwork(_base.AbstractModule):
    """Implementation of a Graph Network.

    The input graph goes through an `EdgeBlock`, then a `NodeBlock`, then a
    `GlobalBlock`, each one configurable through its own options dict.

    See https://arxiv.org/abs/1806.01261 for more details.
    """

    def __init__(self,
                 edge_model_fn,
                 node_model_fn,
                 global_model_fn,
                 reducer=tf.math.unsorted_segment_sum,
                 edge_block_opt=None,
                 node_block_opt=None,
                 global_block_opt=None,
                 name="graph_network"):
        """Initializes the GraphNetwork module.

        Args:
          edge_model_fn: A callable that will be passed to EdgeBlock to
            perform per-edge computations. The callable must return a Sonnet
            module (or equivalent; see EdgeBlock for details).
          node_model_fn: A callable that will be passed to NodeBlock to
            perform per-node computations. The callable must return a Sonnet
            module (or equivalent; see NodeBlock for details).
          global_model_fn: A callable that will be passed to GlobalBlock to
            perform per-global computations. The callable must return a
            Sonnet module (or equivalent; see GlobalBlock for details).
          reducer: Reducer to be used by NodeBlock and GlobalBlock to
            aggregate nodes and edges. Defaults to
            `tf.math.unsorted_segment_sum`. Reducers specified in
            `node_block_opt` and `global_block_opt`, if any, take precedence.
          edge_block_opt: Additional options to be passed to the EdgeBlock.
            Can contain the keys `use_edges`, `use_receiver_nodes`,
            `use_sender_nodes`, `use_globals`. By default, these are all
            True.
          node_block_opt: Additional options to be passed to the NodeBlock.
            Can contain the keys `use_received_edges`, `use_nodes`,
            `use_globals` (all set to True by default), `use_sent_edges`
            (defaults to False), and `received_edges_reducer`,
            `sent_edges_reducer` (default to `reducer`).
          global_block_opt: Additional options to be passed to the
            GlobalBlock. Can contain the keys `use_edges`, `use_nodes`,
            `use_globals` (all set to True by default), and `edges_reducer`,
            `nodes_reducer` (default to `reducer`).
          name: The module name.
        """
        super(GraphNetwork, self).__init__(name=name)

        # Complete the user-provided options with the module defaults.
        edge_block_kwargs = _make_default_edge_block_opt(edge_block_opt)
        node_block_kwargs = _make_default_node_block_opt(
            node_block_opt, reducer)
        global_block_kwargs = _make_default_global_block_opt(
            global_block_opt, reducer)

        # Each options dict is expanded into keyword arguments of its block.
        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(
                edge_model_fn=edge_model_fn, **edge_block_kwargs)
            self._node_block = blocks.NodeBlock(
                node_model_fn=node_model_fn, **node_block_kwargs)
            self._global_block = blocks.GlobalBlock(
                global_model_fn=global_model_fn, **global_block_kwargs)

    def _build(self, graph):
        """Connects the GraphNetwork.

        Args:
          graph: A `graphs.GraphsTuple` containing `Tensor`s. Depending on
            the block options, `graph` may contain `None` fields; but with
            the default configuration, no `None` field is allowed. Moreover,
            when using the default configuration, the features of each node,
            edge and global of `graph` should be concatenable on the last
            dimension.

        Returns:
          An output `graphs.GraphsTuple` with updated edges, nodes and
          globals.
        """
        graph_with_new_edges = self._edge_block(graph)
        graph_with_new_nodes = self._node_block(graph_with_new_edges)
        return self._global_block(graph_with_new_nodes)
class GraphIndependent(_base.AbstractModule):
    """A graph block that applies models to the graph elements independently.

    The inputs and outputs are graphs. The corresponding models are applied
    to each element of the graph (edges, nodes and globals) in parallel and
    independently of the other elements. It can be used to encode or decode
    the elements of a graph.
    """

    def __init__(self,
                 edge_model_fn=None,
                 node_model_fn=None,
                 global_model_fn=None,
                 name="graph_independent"):
        """Initializes the GraphIndependent module.

        Args:
          edge_model_fn: A callable that returns an edge model function. The
            callable must return a Sonnet module (or equivalent). If passed
            `None`, will pass through inputs (the default).
          node_model_fn: A callable that returns a node model function. The
            callable must return a Sonnet module (or equivalent). If passed
            `None`, will pass through inputs (the default).
          global_model_fn: A callable that returns a global model function.
            The callable must return a Sonnet module (or equivalent). If
            passed `None`, will pass through inputs (the default).
          name: The module name.
        """
        super(GraphIndependent, self).__init__(name=name)

        def make_model(model_fn, model_name):
            """Returns the identity for `None`, else a wrapped `model_fn`."""
            if model_fn is None:
                return lambda x: x
            # Wrapping the model function in a module ensures that the ops
            # and variables it creates are scoped analogously to how the
            # Edge/Node/GlobalBlock classes do.
            return _base.WrappedModelFnModule(model_fn, name=model_name)

        with self._enter_variable_scope():
            self._edge_model = make_model(edge_model_fn, "edge_model")
            self._node_model = make_model(node_model_fn, "node_model")
            self._global_model = make_model(global_model_fn, "global_model")

    def _build(self, graph):
        """Connects the GraphIndependent.

        Args:
          graph: A `graphs.GraphsTuple` containing non-`None` edges, nodes
            and globals.

        Returns:
          An output `graphs.GraphsTuple` with updated edges, nodes and
          globals.
        """
        new_edges = self._edge_model(graph.edges)
        new_nodes = self._node_model(graph.nodes)
        new_globals = self._global_model(graph.globals)
        return graph.replace(
            edges=new_edges, nodes=new_nodes, globals=new_globals)
class DeepSets(_base.AbstractModule):
    """DeepSets module.

    Implementation for the model described in
    https://arxiv.org/abs/1703.06114 (M. Zaheer, S. Kottur, S. Ravanbakhsh,
    B. Poczos, R. Salakhutdinov, A. Smola). See also PointNet
    (https://arxiv.org/abs/1612.00593, C. Qi, H. Su, K. Mo, L. J. Guibas) for
    a related model.

    This module operates on sets, which can be thought of as graphs without
    edges. The node features are first updated based on their value and the
    global features, and new global features are then computed based on the
    updated node features.

    Note that in the original model, only the globals are updated in the
    returned graph, while this implementation also returns updated nodes.
    The original model can be reproduced by writing:

    ```
    deep_sets = DeepSets(node_model_fn, global_model_fn)
    output = deep_sets(input)
    output = input.replace(globals=output.globals)
    ```

    This module does not use the edges data or the information contained in
    the receivers or senders; the output graph has the same value in those
    fields as the input graph. Those fields can also have `None` values in
    the input `graphs.GraphsTuple`.
    """

    def __init__(self,
                 node_model_fn,
                 global_model_fn,
                 reducer=tf.math.unsorted_segment_sum,
                 name="deep_sets"):
        """Initializes the DeepSets module.

        Args:
          node_model_fn: A callable to be passed to NodeBlock. The callable
            must return a Sonnet module (or equivalent; see NodeBlock for
            details). The shape of this module's output must equal the shape
            of the input graph's global features, but for the first and last
            axis.
          global_model_fn: A callable to be passed to GlobalBlock. The
            callable must return a Sonnet module (or equivalent; see
            GlobalBlock for details).
          reducer: Reduction to be used when aggregating the nodes in the
            globals. This should be a callable whose signature matches
            `tf.math.unsorted_segment_sum`.
          name: The module name.
        """
        super(DeepSets, self).__init__(name=name)

        # Nodes are updated from their own features and the globals.
        node_block_kwargs = dict(
            node_model_fn=node_model_fn,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=True)
        # Globals are recomputed from the aggregated nodes only.
        global_block_kwargs = dict(
            global_model_fn=global_model_fn,
            use_edges=False,
            use_nodes=True,
            use_globals=False,
            nodes_reducer=reducer)

        with self._enter_variable_scope():
            self._node_block = blocks.NodeBlock(**node_block_kwargs)
            self._global_block = blocks.GlobalBlock(**global_block_kwargs)

    def _build(self, graph):
        """Connects the DeepSets network.

        Args:
          graph: A `graphs.GraphsTuple` containing `Tensor`s, whose edges,
            senders or receivers properties may be `None`. The features of
            every node and global of `graph` should be concatenable on the
            last axis (i.e. the shapes of `graph.nodes` and `graph.globals`
            must match but for their first and last axis).

        Returns:
          An output `graphs.GraphsTuple` with updated nodes and globals.
        """
        graph_with_new_nodes = self._node_block(graph)
        return self._global_block(graph_with_new_nodes)
class CommNet(_base.AbstractModule):
    """CommNet module.

    Implementation for the model originally described in
    https://arxiv.org/abs/1605.07736 (S. Sukhbaatar, A. Szlam, R. Fergus), in
    the version presented in https://arxiv.org/abs/1706.06122 (Y. Hoshen).

    This module internally creates edge features based on the features from
    the nodes sending to that edge, and independently learns an embedding for
    each node. It then uses these edge and node features to compute updated
    node features.

    This module does not use the global nor the edge features of the input,
    but uses its receivers and senders information. The output graph has the
    same value in edge and global fields as the input graph. The edge and
    global features fields may have a `None` value in the input
    `gn_graphs.GraphsTuple`.
    """

    def __init__(self,
                 edge_model_fn,
                 node_encoder_model_fn,
                 node_model_fn,
                 reducer=tf.math.unsorted_segment_sum,
                 name="comm_net"):
        """Initializes the CommNet module.

        Args:
          edge_model_fn: A callable to be passed to EdgeBlock. The callable
            must return a Sonnet module (or equivalent; see EdgeBlock for
            details).
          node_encoder_model_fn: A callable to be passed to the NodeBlock
            responsible for the first encoding of the nodes. The callable
            must return a Sonnet module (or equivalent; see NodeBlock for
            details). The shape of this module's output should match the
            shape of the module built by `edge_model_fn`, but for the first
            and last dimension.
          node_model_fn: A callable to be passed to NodeBlock. The callable
            must return a Sonnet module (or equivalent; see NodeBlock for
            details).
          reducer: Reduction to be used when aggregating the edges in the
            nodes. This should be a callable whose signature matches
            `tf.math.unsorted_segment_sum`.
          name: The module name.
        """
        super(CommNet, self).__init__(name=name)

        # Computes $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122: edges are
        # built from the sender nodes only.
        edge_block_kwargs = dict(
            edge_model_fn=edge_model_fn,
            use_edges=False,
            use_receiver_nodes=False,
            use_sender_nodes=True,
            use_globals=False)
        # Computes $\Phi(x_i)$ in Eq. (2) of 1706.06122: nodes are encoded
        # from their own features only.
        node_encoder_block_kwargs = dict(
            node_model_fn=node_encoder_model_fn,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            received_edges_reducer=reducer,
            name="node_encoder_block")
        # Computes $\Theta(..)$ in Eq. (2) of 1706.06122: nodes are updated
        # from their encoding and the aggregated received edges.
        node_block_kwargs = dict(
            node_model_fn=node_model_fn,
            use_received_edges=True,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            received_edges_reducer=reducer)

        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(**edge_block_kwargs)
            self._node_encoder_block = blocks.NodeBlock(
                **node_encoder_block_kwargs)
            self._node_block = blocks.NodeBlock(**node_block_kwargs)

    def _build(self, graph):
        """Connects the CommNet network.

        Args:
          graph: A `graphs.GraphsTuple` containing `Tensor`s, with non-`None`
            nodes, receivers and senders.

        Returns:
          An output `graphs.GraphsTuple` with updated nodes. All the other
          fields are taken from the input `graph`.

        Raises:
          ValueError: if any of `graph.nodes`, `graph.receivers` or
            `graph.senders` is `None`.
        """
        graph_with_edges = self._edge_block(graph)
        graph_with_encoded_nodes = self._node_encoder_block(graph_with_edges)
        graph_with_new_nodes = self._node_block(graph_with_encoded_nodes)
        # Only the nodes are copied back; the internally created edges are
        # not exposed in the output.
        return graph.replace(nodes=graph_with_new_nodes.nodes)
def _unsorted_segment_softmax(data,
                              segment_ids,
                              num_segments,
                              name="unsorted_segment_softmax"):
    """Performs an elementwise softmax operation along segments of a tensor.

    The input parameters are analogous to `tf.math.unsorted_segment_sum`. The
    output has the same shape as the input data, and results from an
    elementwise softmax taken over all of the rows sharing a segment id.

    Args:
      data: A tensor with at least one dimension.
      segment_ids: A tensor of indices segmenting `data` across the first
        dimension.
      num_segments: A scalar tensor indicating the number of segments. It
        should be at least `max(segment_ids) + 1`.
      name: A name for the operation (optional).

    Returns:
      A tensor with the same shape as `data` after applying the softmax
      operation.
    """
    with tf.name_scope(name):
        # Shift every row by the maximum of its segment before taking the
        # exponential: the softmax value is unchanged and the exponent is
        # never positive.
        # Possibly refactor to `tf.stop_gradient(...)` of the gathered maxima
        # for better performance.
        per_segment_max = tf.math.unsorted_segment_max(
            data, segment_ids, num_segments)
        data -= tf.gather(per_segment_max, segment_ids)

        exponentials = tf.exp(data)
        # Normalise each row by the sum of exponentials of its own segment.
        per_segment_total = tf.math.unsorted_segment_sum(
            exponentials, segment_ids, num_segments)
        return exponentials / tf.gather(per_segment_total, segment_ids)
def _received_edges_normalizer(graph,
                               normalizer,
                               name="received_edges_normalizer"):
    """Normalizes the edges elementwise across the edges of each receiver.

    Args:
      graph: A graph containing edge information. Its `edges`, `receivers`
        and `n_node` fields are read.
      normalizer: A normalizer function following the signature of
        `modules._unsorted_segment_softmax`. It is called with the keyword
        arguments `data`, `segment_ids` and `num_segments`.
      name: A name for the operation (optional).

    Returns:
      A tensor with the resulting normalized edges.
    """
    with tf.name_scope(name):
        edge_data = graph.edges
        receiver_ids = graph.receivers
        # One segment per node, counted over all the graphs of the batch.
        total_num_nodes = tf.reduce_sum(graph.n_node)
        return normalizer(
            data=edge_data,
            segment_ids=receiver_ids,
            num_segments=total_num_nodes)
class SelfAttention(_base.AbstractModule):
    """Multi-head self-attention module.

    The module is based on the following three papers:
     * A simple neural network module for relational reasoning (RNs):
         https://arxiv.org/abs/1706.01427
     * Non-local Neural Networks: https://arxiv.org/abs/1711.07971.
     * Attention Is All You Need (AIAYN): https://arxiv.org/abs/1706.03762.

    The input to the module consists of a graph containing connectivity
    between nodes, a tensor containing values for each node, a tensor
    containing keys for each node and a tensor containing queries for each
    node.

    The self-attention step consists of updating the node values, with each
    new node value computed as follows:
     - Computing the attention weights between each node and all of its
       sender nodes, by calculating sum(sender_key * receiver_query) and
       using the softmax operation on all attention weights for each node.
     - For each receiver node, computing the new node value as the weighted
       average of the values of the sender nodes, according to the attention
       weights.
     - Nodes with no received edges get an updated value of 0.

    Values, keys and queries contain a "head" axis to compute independent
    self-attention for each of the heads.
    """

    def __init__(self, name="self_attention"):
        """Inits the module.

        Args:
          name: The module name.
        """
        super(SelfAttention, self).__init__(name=name)
        # Attention logits are normalised with a per-receiver softmax.
        self._normalizer = _unsorted_segment_softmax

    def _build(self, node_values, node_keys, node_queries, attention_graph):
        """Connects the multi-head self-attention module.

        The self-attention is only computed according to the connectivity of
        the input graphs, with receiver nodes attending to sender nodes.

        Args:
          node_values: Tensor containing the values associated to each of the
            nodes. The expected shape is
            [total_num_nodes, num_heads, value_size].
          node_keys: Tensor containing the key associated to each of the
            nodes. The expected shape is
            [total_num_nodes, num_heads, key_size].
          node_queries: Tensor containing the query associated to each of the
            nodes. The expected shape is
            [total_num_nodes, num_heads, query_size]. The query size must be
            equal to the key size.
          attention_graph: Graph containing connectivity information between
            nodes via the senders and receivers fields. A node attends to
            another node only if `attention_graph` contains an edge sent by
            the latter and received by the former.

        Returns:
          An output `graphs.GraphsTuple` with updated nodes containing the
          aggregated attended value for each of the nodes with shape
          [total_num_nodes, num_heads, value_size].

        Raises:
          ValueError: if the input graph does not have edges.
        """
        # Sender nodes put their keys and values in the edges.
        # [total_num_edges, num_heads, key_size]
        keys_on_edges = blocks.broadcast_sender_nodes_to_edges(
            attention_graph.replace(nodes=node_keys))
        # [total_num_edges, num_heads, value_size]
        values_on_edges = blocks.broadcast_sender_nodes_to_edges(
            attention_graph.replace(nodes=node_values))

        # Receiver nodes put their queries in the edges.
        # [total_num_edges, num_heads, query_size]
        queries_on_edges = blocks.broadcast_receiver_nodes_to_edges(
            attention_graph.replace(nodes=node_queries))

        # One attention logit per edge and head: the dot product between the
        # sender key and the receiver query.
        # [total_num_edges, num_heads]
        logits = tf.reduce_sum(keys_on_edges * queries_on_edges, axis=-1)
        logits_graph = attention_graph.replace(edges=logits)
        weights = _received_edges_normalizer(
            logits_graph, normalizer=self._normalizer)

        # Scale the sender values by the weights; the trailing axis added to
        # the weights broadcasts them over the value features.
        # [total_num_edges, num_heads, value_size]
        weighted_values = values_on_edges * weights[..., None]

        # Sum, for each node, the weighted values carried by the edges it
        # receives.
        # [total_num_nodes, num_heads, value_size]
        sum_into_receivers = blocks.ReceivedEdgesToNodesAggregator(
            reducer=tf.math.unsorted_segment_sum)
        attended_values = sum_into_receivers(
            attention_graph.replace(edges=weighted_values))

        return attention_graph.replace(nodes=attended_values)