【代码分析】Graph-U-Nets

7 篇文章 2 订阅
4 篇文章 0 订阅

Graph-U-Nets(二)代码分析

Graph U-Nets通过Pytorch进行实现,开源代码Graph U-Nets,原论文链接Graph U-Nets

直接对作者的代码进行解读

class GraphUnet(nn.Module):
    def __init__(self, ks, in_dim, out_dim, dim, act, drop_p):
    """
    :param ks: 表示pools层进行的节点采样率,数据类型为float型
    """
        super(GraphUnet, self).__init__()
        self.ks = ks
        # 创建底部的gcn
        self.bottom_gcn = GCN(dim, dim, act, drop_p)
        # 对应原文的encoder中的gcn
        self.down_gcns = nn.ModuleList()
        # 对应原文的decoder中的gcn
        self.up_gcns = nn.ModuleList()
        # 对应原文的encoder中的gPool
        self.pools = nn.ModuleList()
        # 对应原文的decoder中的gUnPool
        self.unpools = nn.ModuleList()
        # ks的长度表示了gUNets的深度
        self.l_n = len(ks)
        # 构建l_n个子模块
        for i in range(self.l_n):
            self.down_gcns.append(GCN(dim, dim, act, drop_p))
            self.up_gcns.append(GCN(dim, dim, act, drop_p))
            self.pools.append(Pool(ks[i], dim, drop_p))
            self.unpools.append(Unpool(dim, dim, drop_p))

    def forward(self, g, h):
    """
    : param g:邻接矩阵
    : param h:特征矩阵
    """
        adj_ms = []
        indices_list = [] # 用于存储TopK的idx
        down_outs = [] # 存储经过encoder中gcn产生的特征
        hs = []
        org_h = h # 原始的特征信息
        # Encoder执行部分
        for i in range(self.l_n):
            h = self.down_gcns[i](g, h) # h为encoder产生的特征
            adj_ms.append(g) # 存储输入第i层的邻接矩阵
            down_outs.append(h) # 存储第i层输出的特征矩阵
            # 经过图池化操作
            g, h, idx = self.pools[i](g, h) # idx存储了重要的TopK的节点信息
            indices_list.append(idx) # 存储第i层保留的idx的信息,用于decoder的信息还原
        # bottom gcn部分
        h = self.bottom_gcn(g, h)
        # Decoder执行部分
        for i in range(self.l_n):
			# 由于decoder部分将编码后的邻接矩阵从小恢复到大
            up_idx = self.l_n - i - 1
            # 分别取出与当前decoder block对应的邻接矩阵g,以及节点索引idx
            g, idx = adj_ms[up_idx], indices_list[up_idx]
            # 通过gUnPool操作恢复
            # 输出将增加邻接矩阵的维度
            g, h = self.unpools[i](g, h, down_outs[up_idx], idx)
            h = self.up_gcns[i](g, h)
            # h为第n_l-i层的特征输出
            # 将对应encoder block与decoder block进行skip connection
            h = h.add(down_outs[up_idx])
            # 存储经过跳跃连接之后的特征矩阵
            hs.append(h)
        h = h.add(org_h) # 将原始的特征矩阵信息与经过所有block输出的特征矩阵的信息进行skip connection
        hs.append(h)
        return hs # 输出最终的Embedding vector

在这里插入图片描述

  • encoder中的下采样操作
class Pool(nn.Module):

    def __init__(self, k, in_dim, p):
        super(Pool, self).__init__()
        self.k = k # 采样节点的比率
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()

    def forward(self, g, h):
        Z = self.drop(h)
        weights = self.proj(Z).squeeze()
        scores = self.sigmoid(weights)
        return top_k_graph(scores, g, h, self.k)
  • TopK算法实现
def top_k_graph(scores, g, h, k):
"""
: param scores:
: param g: 邻接矩阵
: param h: 特征矩阵
: param k: 选择前K个重要的节点
"""
    num_nodes = g.shape[0]
    # 返回前K个元素的值values,以及对应的索引值idx
    values, idx = torch.topk(scores, max(2, int(k*num_nodes)))
    new_h = h[idx, :]
    values = torch.unsqueeze(values, -1)
    new_h = torch.mul(new_h, values)
    un_g = g.bool().float()
    un_g = torch.matmul(un_g, un_g).bool().float()
    un_g = un_g[idx, :]
    un_g = un_g[:, idx]
    g = norm_g(un_g)
    return g, new_h, idx
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值