Graph-U-Nets(二)代码分析
Graph U-Nets通过Pytorch进行实现,开源代码Graph U-Nets,原论文链接Graph U-Nets
直接对作者的代码进行解读
class GraphUnet(nn.Module):
def __init__(self, ks, in_dim, out_dim, dim, act, drop_p):
"""
:param ks: 表示pools层进行的节点采样率,数据类型为float型
"""
super(GraphUnet, self).__init__()
self.ks = ks
# 创建底部的gcn
self.bottom_gcn = GCN(dim, dim, act, drop_p)
# 对应原文的encoder中的gcn
self.down_gcns = nn.ModuleList()
# 对应原文的decoder中的gcn
self.up_gcns = nn.ModuleList()
# 对应原文的encoder中的gPool
self.pools = nn.ModuleList()
# 对应原文的decoder中的gUnPool
self.unpools = nn.ModuleList()
# ks的长度表示了gUNets的深度
self.l_n = len(ks)
# 构建l_n个子模块
for i in range(self.l_n):
self.down_gcns.append(GCN(dim, dim, act, drop_p))
self.up_gcns.append(GCN(dim, dim, act, drop_p))
self.pools.append(Pool(ks[i], dim, drop_p))
self.unpools.append(Unpool(dim, dim, drop_p))
def forward(self, g, h):
"""
: param g:邻接矩阵
: param h:特征矩阵
"""
adj_ms = []
indices_list = [] # 用于存储TopK的idx
down_outs = [] # 存储经过encoder中gcn产生的特征
hs = []
org_h = h # 原始的特征信息
# Encoder执行部分
for i in range(self.l_n):
h = self.down_gcns[i](g, h) # h为encoder产生的特征
adj_ms.append(g) # 存储输入第i层的邻接矩阵
down_outs.append(h) # 存储第i层输出的特征矩阵
# 经过图池化操作
g, h, idx = self.pools[i](g, h) # idx存储了重要的TopK的节点信息
indices_list.append(idx) # 存储第i层保留的idx的信息,用于decoder的信息还原
# bottom gcn部分
h = self.bottom_gcn(g, h)
# Decoder执行部分
for i in range(self.l_n):
# 由于decoder部分将编码后的邻接矩阵从小恢复到大
up_idx = self.l_n - i - 1
# 分别取出与当前decoder block对应的邻接矩阵g,以及节点索引idx
g, idx = adj_ms[up_idx], indices_list[up_idx]
# 通过gUnPool操作恢复
# 输出将增加邻接矩阵的维度
g, h = self.unpools[i](g, h, down_outs[up_idx], idx)
h = self.up_gcns[i](g, h)
# h为第n_l-i层的特征输出
# 将对应encoder block与decoder block进行skip connection
h = h.add(down_outs[up_idx])
# 存储经过跳跃连接之后的特征矩阵
hs.append(h)
h = h.add(org_h) # 将原始的特征矩阵信息与经过所有block输出的特征矩阵的信息进行skip connection
hs.append(h)
return hs # 输出最终的Embedding vector
- encoder中的下采样操作
class Pool(nn.Module):
def __init__(self, k, in_dim, p):
super(Pool, self).__init__()
self.k = k # 采样节点的比率
self.sigmoid = nn.Sigmoid()
self.proj = nn.Linear(in_dim, 1)
self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
def forward(self, g, h):
Z = self.drop(h)
weights = self.proj(Z).squeeze()
scores = self.sigmoid(weights)
return top_k_graph(scores, g, h, self.k)
- TopK算法实现
def top_k_graph(scores, g, h, k):
"""
: param scores:
: param g: 邻接矩阵
: param h: 特征矩阵
: param k: 选择前K个重要的节点
"""
num_nodes = g.shape[0]
# 返回前K个元素的值values,以及对应的索引值idx
values, idx = torch.topk(scores, max(2, int(k*num_nodes)))
new_h = h[idx, :]
values = torch.unsqueeze(values, -1)
new_h = torch.mul(new_h, values)
un_g = g.bool().float()
un_g = torch.matmul(un_g, un_g).bool().float()
un_g = un_g[idx, :]
un_g = un_g[:, idx]
g = norm_g(un_g)
return g, new_h, idx