【笔记】GATED GRAPH SEQUENCE NEURAL NETWORKS

基于门控序列的图神经网络

论文:Li Y , Tarlow D , Brockschmidt M , et al. Gated Graph Sequence Neural Networks[J]. Computer Science, 2015.
代码:gated-graph-neural-network-samples

1 图神经网络

1.1 公式化描述

不整那些虚的,直接上公式
(1)
上式中h是第t个时间步的点嵌入(node embedding),也可以理解为是一个点的特征矩阵,维度是(node, feature),传入的参数分别是node label,相邻点的node label,相邻边的edge label,上一个时间步的h
在这里插入图片描述
上式关键就是考虑到边是有向的,因此将图计算分成了入边、出边两个部分
在这里插入图片描述
这是图神经网路的核心部分(graph),也就是说用图A与上一个时间步的点嵌入做矩阵乘法,具体的维度变化是:(node, node) x (node, feature) => (node, feature)

1.2 图形描述

在这里插入图片描述
(a)是假设的一个图,由于图的方向性,可以画出一个如图©所示的特殊邻接矩阵
(b)则直观描述了图神经网络在做什么,神经元的连接和图边连接是直接对应的,且不同种类的边是区分开的

2 基于门控序列的图神经网络

2.1 节点标注(初始化h)

以节点的到达关系为例,假设我们的任务是判断节点s是否可以到达节点t,则设x_s=[1,0].T,,x_t=[0,1].T,其他节点是[0,0].T,将这些节点标注连起来,再padding成feature_dim维度,具体可见公式1

2.2 门控

实际上就是一个GRU的变体
在这里插入图片描述
值得注意的是上式中的
(1)也就是标注过程,
(2)是图神经网络连接
(3-6)就是中规中矩的GRU

2.3 传播

在这里插入图片描述
大概说每一个时间步,都要计算一个输出F_o,也要计算一个送入下一个时间步的输入F_x

3 代码分析

只看论文确实让人觉得玄学,特别是annotation部分,很迷
结合代码来看就好很多,这里例举的是@JamesChuanggg的pytorch实现ggnn.pytorch,这个实现的代码相比于官方版本来说,容易读很多

3.1 annotation
annotation = np.zeros([n_nodes, n_annotation_dim])
annotation[target[1]-1][0] = 1

核心实现就是上面这个,除了表达到达关系部分用了1,其他padding成了0

3.2 每一个时间步的实现
class Propogator(nn.Module):
    """
    Gated Propogator for GGNN
    Using LSTM gating mechanism
    """
    def __init__(self, state_dim, n_node, n_edge_types):
        ## 初始化参照源代码

    def forward(self, state_in, state_out, state_cur, A):
        # 入边向量和出边向量
        A_in = A[:, :, :self.n_node*self.n_edge_types]
        A_out = A[:, :, self.n_node*self.n_edge_types:]

        # 入边向量和出边向量分别和图做计算
        a_in = torch.bmm(A_in, state_in)
        a_out = torch.bmm(A_out, state_out)
        a = torch.cat((a_in, a_out, state_cur), 2)

        # 类GRU部分
        r = self.reset_gate(a)
        z = self.update_gate(a)
        joined_input = torch.cat((a_in, a_out, r * state_cur), 2)
        h_hat = self.tansform(joined_input)

        output = (1 - z) * state_cur + z * h_hat

        return output
3.3 网络结构
class GGNN(nn.Module):
    """
    Gated Graph Sequence Neural Networks (GGNN)
    Mode: SelectNode
    Implementation based on https://arxiv.org/abs/1511.05493
    """
    def __init__(self, opt):
        # 初始化参考源代码
    def forward(self, prop_state, annotation, A):
        # prop_state:论文中的h
        # annotation:节点标注
        # A:图
        for i_step in range(self.n_steps):
            # 对于每一个时间步循环
            in_states = []
            out_states = []
            for i in range(self.n_edge_types):
                # 对输入特征做两个分支的全连接,得到入边特征,和出边特征
                # 每一种边都要计算一次
                in_states.append(self.in_fcs[i](prop_state))
                out_states.append(self.out_fcs[i](prop_state))
            # 将所有种类的边得到的特征连接起来
            in_states = torch.stack(in_states).transpose(0, 1).contiguous()
            in_states = in_states.view(-1, self.n_node*self.n_edge_types, self.state_dim)
            out_states = torch.stack(out_states).transpose(0, 1).contiguous()
            out_states = out_states.view(-1, self.n_node*self.n_edge_types, self.state_dim)

            # 用门控图模块更新h
            prop_state = self.propogator(in_states, out_states, prop_state, A)

        join_state = torch.cat((prop_state, annotation), 2)

        output = self.out(join_state)
        output = output.sum(2)

        return output
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值