deep graph infomax代码阅读总结

ICLR 2019。
ps:我觉得论文看method看不大懂,不如直接去看代码最清楚。

1.一种无监督的训练方式,核心:最大化互信息。(全图的信息与正样本局部信息最大化,全图的信息与负样本局部信息最小化。)
2.大致流程:通过最大化互信息训练图嵌入结果(无监督),训练线性分类器(有监督)完成图分类任务。
3.DGI定义:

class DGI(nn.Module):
    # ft_size, hid_units, nonlinearity
    def __init__(self, n_in, n_h, activation):
        super(DGI, self).__init__()
        self.gcn = GCN(n_in, n_h, activation)
        self.read = AvgReadout() #读出函数,其实这里就是所有节点表示的均值

        self.sigm = nn.Sigmoid()

        self.disc = Discriminator(n_h) #判别器,定义为一个双线性函数bilinear

    def forward(self, seq1, seq2, adj, sparse, msk, samp_bias1, samp_bias2): # msk: None, samp_bias1: None, samp_bias2: None,

        h_1 = self.gcn(seq1, adj, sparse)

        c = self.read(h_1, msk)
        c = self.sigm(c) # c表示全图信息

        h_2 = self.gcn(seq2, adj, sparse)

        ret = self.disc(c, h_1, h_2, samp_bias1, samp_bias2) #计算c-h_1,c_h_2的双线性判别器的结果

        return ret

    # Detach the return variables
    def embed(self, seq, adj, sparse, msk):
        h_1 = self.gcn(seq, adj, sparse)
        c = self.read(h_1, msk)

        return h_1.detach(), c.detach() #将tensor从计算图中分离出来,不参与反向传播

4.Discriminator定义

class Discriminator(nn.Module):
    def __init__(self, n_h):
        super(Discriminator, self).__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1) # 双线性层 x_1 W x_2 + b, 输出 batch * 1 维度,相当于输出表示两个输入之间的关系?

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, nn.Bilinear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None): #c应该与h_pl接近,与h_mi远离?

        c_x = torch.unsqueeze(c, 1)
        c_x = c_x.expand_as(h_pl)

        sc_1 = torch.squeeze(self.f_k(h_pl, c_x), 2)

        sc_2 = torch.squeeze(self.f_k(h_mi, c_x), 2)

        if s_bias1 is not None:
            sc_1 += s_bias1
        if s_bias2 is not None:
            sc_2 += s_bias2

        logits = torch.cat((sc_1, sc_2), 1)

        return logits

4.训练:
正样例:features
负样例:打乱了顺序的features

model = DGI(ft_size, hid_units, nonlinearity) #模型的创建
b_xent = nn.BCEWithLogitsLoss()

lbl_1 = torch.ones(batch_size, nb_nodes) #正样本标签
    lbl_2 = torch.zeros(batch_size, nb_nodes) #负样本标签
    lbl = torch.cat((lbl_1, lbl_2), 1) # shape: torch.Size([1, 5416])

logits = model(features, shuf_fts, sp_adj if sparse else adj, sparse, None, None, None) 
# shuf_fts代表打乱了顺序的features
loss = b_xent(logits, lbl)
  • 3
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Deep Graph Infomax是一篇由Petar Veličković等人于2019年在ICLR上发表的论文。该论文提出了一种基于图对比学习的方法,旨在学习图数据的表示。方法中使用了一个GNN Encoder来将图的节点编码为向量表示,通过一个Read-out函数将节点表示汇总为整个图的表示向量。同时,该方法对原始图进行扰动,并使用相同的GNN Encoder对扰动后的图进行编码,然后通过一个Decoder来使图的表示与原始图的节点表示更接近,并使扰动后的图的节点表示与原始图的节点表示更加疏远。这篇论文的贡献是提出了一种基于互信息最大化的自监督图学习通用框架。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [图对比学习三篇顶会论文](https://blog.csdn.net/qq_51072801/article/details/130251996)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}}] [.reference_item style="max-width: 50%"] - *2* [论文阅读Deep Graph Infomax(DGI)》](https://blog.csdn.net/m0_71014828/article/details/125199457)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值