基于门控序列的图神经网络
论文:Li Y , Tarlow D , Brockschmidt M , et al. Gated Graph Sequence Neural Networks[J]. Computer Science, 2015.
代码:gated-graph-neural-network-samples
1 图神经网络
1.1 公式化描述
不整那些虚的,直接上公式
上式中h
是第t
个时间步的点嵌入(node embedding),也可以理解为是一个点的特征矩阵,维度是(node, feature),传入的参数分别是node label,相邻点的node label,相邻边的edge label,上一个时间步的h
上式关键就是考虑到边是有向的,因此将图计算分成了入边、出边两个部分
这是图神经网路的核心部分(graph),也就是说用图A与上一个时间步的点嵌入做矩阵乘法,具体的维度变化是:(node, node) x (node, feature) => (node, feature)
1.2 图形描述
(a)是假设的一个图,由于图的方向性,可以画出一个如图©所示的特殊邻接矩阵
(b)则直观描述了图神经网络在做什么,神经元的连接和图边连接是直接对应的,且不同种类的边是区分开的
2 基于门控序列的图神经网络
2.1 节点标注(初始化h)
以节点的到达关系为例,假设我们的任务是判断节点s
是否可以到达节点t
,则设x_s=[1,0].T
,,x_t=[0,1].T
,其他节点是[0,0].T
,将这些节点标注连起来,再padding成feature_dim
维度,具体可见公式1
2.2 门控
实际上就是一个GRU的变体
值得注意的是上式中的
(1)也就是标注过程,
(2)是图神经网络连接
(3-6)就是中规中矩的GRU
2.3 传播
大概说每一个时间步,都要计算一个输出F_o
,也要计算一个送入下一个时间步的输入F_x
3 代码分析
只看论文确实让人觉得玄学,特别是annotation部分,很迷
结合代码来看就好很多,这里例举的是@JamesChuanggg的pytorch实现ggnn.pytorch,这个实现的代码相比于官方版本来说,容易读很多
3.1 annotation
annotation = np.zeros([n_nodes, n_annotation_dim])
annotation[target[1]-1][0] = 1
核心实现就是上面这个,除了表达到达关系部分用了1,其他padding成了0
3.2 每一个时间步的实现
class Propogator(nn.Module):
"""
Gated Propogator for GGNN
Using LSTM gating mechanism
"""
def __init__(self, state_dim, n_node, n_edge_types):
## 初始化参照源代码
def forward(self, state_in, state_out, state_cur, A):
# 入边向量和出边向量
A_in = A[:, :, :self.n_node*self.n_edge_types]
A_out = A[:, :, self.n_node*self.n_edge_types:]
# 入边向量和出边向量分别和图做计算
a_in = torch.bmm(A_in, state_in)
a_out = torch.bmm(A_out, state_out)
a = torch.cat((a_in, a_out, state_cur), 2)
# 类GRU部分
r = self.reset_gate(a)
z = self.update_gate(a)
joined_input = torch.cat((a_in, a_out, r * state_cur), 2)
h_hat = self.tansform(joined_input)
output = (1 - z) * state_cur + z * h_hat
return output
3.3 网络结构
class GGNN(nn.Module):
"""
Gated Graph Sequence Neural Networks (GGNN)
Mode: SelectNode
Implementation based on https://arxiv.org/abs/1511.05493
"""
def __init__(self, opt):
# 初始化参考源代码
def forward(self, prop_state, annotation, A):
# prop_state:论文中的h
# annotation:节点标注
# A:图
for i_step in range(self.n_steps):
# 对于每一个时间步循环
in_states = []
out_states = []
for i in range(self.n_edge_types):
# 对输入特征做两个分支的全连接,得到入边特征,和出边特征
# 每一种边都要计算一次
in_states.append(self.in_fcs[i](prop_state))
out_states.append(self.out_fcs[i](prop_state))
# 将所有种类的边得到的特征连接起来
in_states = torch.stack(in_states).transpose(0, 1).contiguous()
in_states = in_states.view(-1, self.n_node*self.n_edge_types, self.state_dim)
out_states = torch.stack(out_states).transpose(0, 1).contiguous()
out_states = out_states.view(-1, self.n_node*self.n_edge_types, self.state_dim)
# 用门控图模块更新h
prop_state = self.propogator(in_states, out_states, prop_state, A)
join_state = torch.cat((prop_state, annotation), 2)
output = self.out(join_state)
output = output.sum(2)
return output