本文为图神经网络的学习笔记,讲解基本 Graph Net。欢迎在评论区与我交流👏
教程
该教程的目的是通过以下例子对图网络库进行实践:
- 使用
graph_nets.utils_np
构造图数据结构graph_nets.graphs.GraphsTuple
- 使用
graph_nets.utils_tf
在 tensorflow 图中操作图数据结构 - 在
graph_nets.modules
中将图喂给图网络 tensorflow 模型 - 使用
graph_nets.blocks
提供的图网络构建块构造自定义的图网络模型
更多关于图网络的信息,见【论文】。
在 Colaboratory runtime 上安装图网络库:
install_graph_nets_library = "No" #@param ["Yes", "No"]
if install_graph_nets_library.lower() == "yes":
print("Installing Graph Nets library with:")
print(" $ pip install graph_nets\n")
print("Output message from command:\n")
!pip install graph_nets
else:
print("Skipping installation of Graph Nets library")
引入包:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from graph_nets import blocks
from graph_nets import graphs
from graph_nets import modules
from graph_nets import utils_np
from graph_nets import utils_tf
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import sonnet as snt
import tensorflow as tf
graphs.GraphsTuple
类
图网络库包含了在图结构数据上的模型,因此首先要理解图结构数据是怎样用代码表示的。
定义在 graph_nets/graphs.py
中的 graphs.GraphsTuple
类,表示一批一个或多个图(a batches of one or more graphs)。所有的图网络模型以 GraphsTuple
的实例作为输入,返回 GraphsTuple
实例作为输出。图为有向图(单向边),有属性(允许节点、边和图级别特征)、多图(多边可以连接任意两个节点,允许自环)。
GraphsTuple
的属性有:
n_node
(shape=[num_graphs]):批处理中每个图的节点数n_edge
(shape=[num_graphs]):批处理中每个图的边数globals
(shape=[num_graphs] + global_feature_dimensions):批处理中每个图的全局特征nodes
(shape=[total_num_nodes] + node_feature_dimensions):该批图中每个节点的节点特性edges
(shape=[total_num_edges] + edge_feature_dimensions):该批图中每个边的边特性senders
(shape=[total_num_edges]):nodes
中节点序号,表示edges
中每条有向边的源节点receivers
(shape=[total_num_edges]):nodes
中节点序号,表示edges
中每条有向边的目的节点
来自批处理不同图的节点和边用 nodes
和 edges
的第一个维度连接,并且可以分别使用 n_node
和 n_edge
字段进行划分。除了 n_*
字段外,所有字段可选。
GraphsTuple
实例的属性通常是 Numpy arrays 或 TensorFlow tensors。
该库包含了一些实用工具,可以使用这些类型的属性来操作图:
utils_np
(Numpy arrays)utils_tf
(TensorFlow tensors)
GraphsTuple
类中一个重要的方法是 GraphsTuple.replace
。和 collections.namedtuple._replace
类似(实际为其子类),这个方法创建了 GraphsTuple
的拷贝,引用了所有的原始属性,用提供的关键字参数对其中的一些值进行替换。
创建图
图中包含什么?
每个图有一个全局特征,一些点和边。图的点和边数可以不同,但全局、节点和边的属性向量必须在全图中相同。为创建 graphs.GraphsTuple
实例,我们可以定义一个 list
,元素是 dict
s,使用以下的键,键包括每个图的数据:
- “globals”:每个图有一个
float
值特征向量 - “Nodes”:每个图有一个节点集,有
float
值特征向量 - “globals”:每个图有一个边集,有
float
值特征向量 - “senders”:每个边连接的一个发送节点,用
int
值的节点序号表示 - “receivers”:每个边连接的一个接受节点,用
int
值的节点序号表示
下面的代码创建了虚拟图数据:
# Global features for graph 0.
globals_0 = [1., 2., 3.]
# Node features for graph 0.
nodes_0 = [[10., 20., 30.], # Node 0
[11., 21., 31.], # Node 1
[12., 22., 32.], # Node 2
[13., 23., 33.], # Node 3
[14., 24., 34.]] # Node 4
# Edge features for graph 0.
edges_0 = [[100., 200.], # Edge 0
[101., 201.], # Edge 1
[102., 202.], # Edge 2
[103., 203.], # Edge 3
[104., 204.], # Edge 4
[105., 205.]] # Edge 5
# The sender and receiver nodes associated with each edge for graph 0.
senders_0 = [0, # Index of the sender node for edge 0
1, # Index of the sender node for edge 1
1, # Index of the sender node for edge 2
2, # Index of the sender node for edge 3
2, # Index of the sender node for edge 4
3] # Index of the sender node for edge 5
receivers_0 = [1, # Index of the receiver node for edge 0
2, # Index of the receiver node for edge 1
3, # Index of the receiver node for edge 2
0, # Index of the receiver node for edge 3
3, # Index of the receiver node for edge 4
4] # Index of the receiver node for edge 5
# Global features for graph 1.
globals_1 = [1001., 1002., 1003.]
# Node features for graph 1.
nodes_1 = [[1010., 1020., 1030.], # Node 0
[1011., 1021., 1031.]] # Node 1
# Edge features for graph 1.
edges_1 = [[1100., 1200.], # Edge 0
[1101., 1201.], # Edge 1
[1102., 1202.], # Edge 2
[1103., 1203.]] # Edge 3
# The sender and receiver nodes associated with each edge for graph 1.
senders_1 = [0, # Index of the sender node for edge 0
0, # Index of the sender node for edge 1
1, # Index of the sender node for edge 2
1] # Index of the sender node for edge 3
receivers_1 = [0, # Index of the receiver node for edge 0
1, # Index of the receiver node for edge 1
0, # Index of the receiver node for edge 2
0] # Index of the receiver node for edge 3
data_dict_0 = {
"globals": globals_0,
"nodes": nodes_0,
"edges": edges_0,
"senders": senders_0,
"receivers": receivers_0
}
data_dict_1 = {
"globals": globals_1,
"nodes": nodes_1,
"edges": edges_1,
"senders": senders_1,
"receivers": receivers_1
}
将图表示为 graphs.GraphsTuple
utils_np
包括了名为 utils_np.data_dicts_to_graphs_tuple
的函数,传入上面指定键的 dict
s 的 list
,返回 GraphsTuple
表示图序列。
data_dicts_to_graphs_tuple
做了三件事:
- 将多个图的数据按最内层的轴(即批处理维度)连接在一起,使得图网络可以通过一个共享函数并行处理节点和边的属性
- 计算每个图的节点数和边数,并将它们分别存储在字段
n_node
和n_edge
中,长度等于图的数量。这用于跟踪哪些节点和边属于哪个图,以便稍后进行分割。这样图就可以跨节点和边传播图的全局属性 - 向发送节点和接受节点索引加一个整数偏移量,该偏移量对应于前面图中的节点数。这使得索引,在连接节点和边属性后,对应于对应图的节点和边
使用 utils_np.data_dicts_to_graphs_tuple
将图字典放入 GraphsTuple
:
data_dict_list = [data_dict_0, data_dict_1]
graphs_tuple = utils_np.data_dicts_to_graphs_tuple(data_dict_list)
使用 networkx
对图可视化
GraphsTuple
可转换为 networkx
图对象 list
,以便于可视化。
可视化前面定义的图:
graphs_nx = utils_np.graphs_tuple_to_networkxs(graphs_tuple