GraphSAGE无监督学习DGL实现简单梳理

DGL中master分支2020.08.20版本的GraphSAGE无监督的实现梳理。因为master分支变化很大,所以可能以后代码会不太一样。
代码地址:https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_sampling_unsupervised.py

1.采样是根据边的id来采的,而且使用了整个graph的所有边。

n_edges = g.number_of_edges()
train_seeds = np.arange(n_edges)

具体的dataloader(即得到每个batch真正训练的数据)代码如下:

    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_seeds, sampler, exclude='reverse_id',
        # For each edge with ID e in Reddit dataset, the reverse edge is e ± |E|/2.
        reverse_eids=th.cat([
            th.arange(n_edges // 2, n_edges),
            th.arange(0, n_edges // 2)]),
        negative_sampler=NegativeSampler(g, args.num_negs),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        pin_memory=True,
        num_workers=args.num_workers)

训练时得到的一个batch训练数据代码如下:

for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):

这里整体的流程应该如下:

  • Dataloader得到train_seeds(graph中所有边的id),每次获取一个batch_size数量的e_id,根据这个e_id得到其两头的结点srcdst,构建一个正样本的子图pos_graph,负样本的子图neg_graph则是通过NegativeSampler,随机替换掉dst构建而成的,假设替换为了dst_neg需要注意的是,pos_graphneg_graph最终包含的结点其实都是src,dstdst_neg(其中的边关系该是怎么样还是怎么样,原因是计算loss的时候需要,可以直接把算出的特征赋值给pos_graphneg_graph),最终将以src,dstdst_neg一起作为seeds,进行sage的子图采样,采样完成的最外层结点会通过input_nodes返回,用于取出对应结点的特征。

2.loss计算
代码如下:

class CrossEntropyLoss(nn.Module):
    def forward(self, block_outputs, pos_graph, neg_graph):
        with pos_graph.local_scope():
            pos_graph.ndata['h'] = block_outputs
            pos_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            pos_score = pos_graph.edata['score']
        with neg_graph.local_scope():
            neg_graph.ndata['h'] = block_outputs
            neg_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            neg_score = neg_graph.edata['score']

        score = th.cat([pos_score, neg_score])
        label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)]).long()
        loss = F.binary_cross_entropy_with_logits(score, label.float())
        return loss

可以看到,最终sage得到的每一个batch的输出block_outpus直接赋值给了pos_graphneg_graphndata['h'],这里可以直接赋值的原因就是因为pos_graphneg_graph中的结点个数和block_outputs的维度相同,因为是以这两个图中的结点作为seeds进行的邻居采样。

具体loss的计算,这里使用的是F.binary_cross_entropy_with_logits,和论文中的好像有一点不一样,但是效果应该是相同的。
论文中的公式为:

J g ( z u ) = − l o g ( σ ( z u T z v ) ) − Q ⋅ E v n ∼ P n ( v ) l o g ( σ ( − z u T z v n ) ) J_g(z_u)=-log(\sigma(z^T_uz_v))-Q·\mathbb{E}_{v_n\sim P_n(v)}log(\sigma(-z^T_uz_{v_n})) Jg(zu)=log(σ(zuTzv))QEvnPn(v)log(σ(zuTzvn))

代码中的F.binary_cross_entropy_with_logits公式为:

l n = − w n [ y n ⋅ l o g σ ( x n ) + ( 1 − y n ) ⋅ l o g ( 1 − σ ( x n ) ) ] l_n=-w_n[y_n·log\sigma(x_n)+(1-y_n)·log(1-\sigma(x_n))] ln=wn[ynlogσ(xn)+(1yn)log(1σ(xn))]

因为正样本的类别为1,负样本的类别为0,因此分别代入样本类别 y n y_n yn中,这两个公式主要区别就在于后半部分是 l o g ( σ ( − z u T z v n ) log(\sigma(-z^T_uz_{v_n}) log(σ(zuTzvn)还是 l o g ( 1 − σ ( − z u T z v n ) ) log(1-\sigma(-z^T_uz_{v_n})) log(1σ(zuTzvn))了,整体应该影响不大。

3.后续测试的时候,是通过一个逻辑回归,迭代1w次进行的,这里才区分了训练集和测试集。

TODO:弄明白EdgeDataLoader中的exlude参数的作用,为什么要提供reverse_eids

exclude参数主要是为了防止information leakage. reverse_eids就是要剔除的eid.

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值