Blog13 图神经网络的模型级解释——关键代码分析5+大总结

本文详细分析了TrainGenerator.py模块,探讨了图神经网络(GNN)的训练过程,包括梯度下降的直观解释和参数更新。通过训练,GNN模型能捕获关键图形模式,例如在化学化合物中的碳环和NO2基团,用于预测其性质。文章总结了项目经验,指出解释方法XGNN有助于理解和改进GNN模型。
摘要由CSDN通过智能技术生成

2021SC@SDUSC



本篇代码分析模块为:TrainGenerator.py

前面一篇博客中,我们分析了图生成器的定义过程,为了是模型达到更好的效果,我们下面来看一下作者是如何进行训练的。

 一.训练图生成器

(1)一些初始化设置

 # T = 10
max_gen_step = 10 
candidate_set = ['C.4', 'N.5', 'O.2', 'F.1', 'I.7', 'Cl.7', 'Br.5']  

'C.4'表明碳原子的度不超过4。

(2)梯度下降更新参数,训练模型

这里实际上涉及到了梯度下降的方法,我们首先来看看梯度下降的一个直观的解释。比如我们在一座大山上的某处位置,由于我们不知道怎么下山,于是决定走一步算一步,也就是在每走到一个位置的时候,求解当前位置的梯度,沿着梯度的负方向,也就是当前最陡峭的位置向下走一步,然后继续求解当前位置梯度,向这一步所在位置沿着最陡峭最易下山的位置走一步。这样一步步的走下去,一直走到觉得我们已经到了山脚。当然这样走下去,有可能我们不能走到山脚,而是到了某一个局部的山峰低处。

从上面的解释可以看出,梯度下降不一定能够找到全局的最优解,有可能是一个局部最优解。当然,如果损失函数是凸函数,梯度下降法得到的解就一定是全局最优解。

## 训练generator
def train_generator(c=0, max_nodes=5):
    g.c = c
    for i in range(max_gen_step):
        optimizer.zero_grad()
        G = copy.deepcopy(g.G)
        p_start, a_start, p_end, a_end, G = g.forward(G)

        Rt = g.calculate_reward(G)
        loss = g.calculate_loss(Rt, p_start, a_start, p_end, a_end, G)
        loss.backward()
        optimizer.step()

        if G['num_nodes'] > max_nodes:
            g.reset_graph()
        elif Rt > 0:
            g.G = G

optimizer.zero_grad():意思是把梯度置零,也就是把loss关于weight的导数变成0。因为一个batch的loss关于weight的导数是所有sample的loss关于weight的导数的累加和。

copy.deepcopy():深拷贝,即我们通常理解的字面"复制"的意思。deepcopy的时候会将复杂对象的每一层复制一个单独的个体出来。

g.calculate_loss():算一下loss标量出来,确定优化的目标,才能进行反向传播。反向传播基于一个loss标量,loss从中获取grad。

loss.backward():反向传播求梯度

optimizer.step():更新所有参数

下图是我做的一些笔记,解释了上述过程:

5.画图

## 画图
def display_graph(G):
    G_nx = nx.from_numpy_matrix(np.asmatrix(G['adj'][:G['num_nodes'], :G['num_nodes']].numpy()))
    # nx.draw_networkx(G_nx)

    layout=nx.spring_layout(G_nx)
    nx.draw(G_nx, layout)

    coloring=torch.argmax(G['feat'],1)
    colors=['b','g','r','c','m','y','k']

    for i in range(7):
        nx.draw_networkx_nodes(G_nx,pos=layout,nodelist=[x for x in G_nx.nodes() if coloring[x]==i],node_color=colors[i])
        nx.draw_networkx_labels(G_nx,pos=layout,labels={x:candidate_set[i].split('.')[0] for x in G_nx.nodes() if coloring[x]==i})
    nx.draw_networkx_edges(G_nx,pos=layout,width=list(nx.get_edge_attributes(G_nx,'weight').values()))
    nx.draw_networkx_edge_labels(G_nx,pos=layout,edge_labels=nx.get_edge_attributes(G_nx, "weight"))

    plt.show()

draw_networkx_nodes():

G必选网络图
pos字典,可选以节点为键,位置为值的字典 ,位置应为长度为2的序列
axMatplotlib Axes对象,可选在指定的Matplotlib轴上绘制图形。
nodelist列表,可选(默认G.nodes())仅绘制指定的节点

6.生成结果


if __name__ == '__main__':
    g = Generator(model_path = model_path, C = candidate_set, node_feature_dim=7 ,c=0, start=0)
    optimizer = optim.Adam(g.parameters(), lr=lr, betas=(b1, b2))

    for i in range(1, 10):
        ## 生成最多分别包括i个结点的图结构
        g.reset_graph()
        train_generator(c=1, max_nodes=i)
        to_display = generate_graph(c=1, max_nodes=i)
        display_graph(to_display)
        print(g.model(to_display['feat'], to_display['adj']))

生成最大节点分别为1-10的预测子图结构。

 

 

  

     

 二、结果分析

在论文中,作者将初始图设置为单个碳原子,因为通常任何有机化合物都含有碳 。并给出了具有不同节点限制和 GNN 预测概率的生成图。

 第一行报告对“非诱变”类的解释,而第二行显示 “诱变”类。

对于诱变,我们可以观察到碳环和 NO2 是一些常见模式,这与碳环和 NO2 化学基团具有致突变性的化学事实一致 。这样的观察表明,经过训练的 GNN 分类器可以捕获这些关键图形模式以进行预测。

对于非诱变”类,我们观察到原子氯广泛存在于生成的图中,氯、溴和氟的组合总是导致非诱变预测。通过分析这些解释,我们可以更好地理解经过训练的 GNN 模型。

此外, 还探索了不同的初始图结果。我们将最大节点限制固定为 5 并生成对“mutagenic”类的解释。 首先,无论我们如何设置初始图,我们提出的方法总能找到最大化“诱变”类的预测概率的图模式。 对于前 5 个图形,这意味着初始图形设置为碳、氮、氧、碘或氟的单个节点,一些生成的图形仍然具有常见的模式,如碳环和 NO2 化学基团。我们的观察进一步证实了这些关键模式由经过训练的 GNN 捕获。

此外,我们注意到生成器仍然可以生成带有氯的图,这些图被预测为“诱变”,这与我们上面的结论相反。如果所有含有氯的图都应该被识别为非诱变的,这样的解释就表明了经过训练的 GNN 的局限性。然后这些生成的解释可以为改进训练的 GNN 提供指导,例如,我们在训练 GNN 时可能会更加重视图氯。

三、大总结

1.项目经历总结

本次项目复现了模型级可解释图神经网络的论文的工作,分析并跑通了源码,最终结合作者给出的实验参照,验证了实验结果。由于作者并未提供论文中实验所用的源码,因此我们得出的效果与论文中还是有很大差异,但是不影响我们理解模型的思想。

总结一下,在这项工作中,作者提出了一种新方法 XGNN 来解释模型级别的图模型。具体来说,我们建议通过图生成器找到可以最大化某个预测的图模式。我们将其表述为强化学习问题并迭代生成图模式。我们训练一个图生成器,对于每一步,它都会预测如何在当前图中添加一条边。此外,我们结合了几个图规则来鼓励生成的图是有效的和人类可理解的。最后,我们对合成数据集和真实数据集进行了实验,以证明我们提出的 XGNN 的有效性。

最终实验结果表明,作者提出的解释方法 XGNN 可以帮助验证、理解甚至帮助改进训练的 GNN 模型,以正确捕获我们想要的模式。 

2.课程收获总结

通过软件工程应用与实践的课程,我有了这次机会去真正接触项目,接触科研,了解一个项目的源码究竟与我们平时的程序有什么不同。

从第一次阅读一篇论文的无从下手,到慢慢摸索、请教老师学长,掌握了如何把握论文的重点脉络,从长篇大论中提取关键有效的信息。从刚开始配置项目环境,频频报错的不知所措,到一步一步学会调试,学会Debug程序查看文件之间的关联,将项目跑通,进一步理解了代码的思想真正与开发人员或者说代码的编写者进行思想的交流。

13余篇博客不仅仅记录了我的学习分析过程,也体现了我的缓慢成长过程,两篇论文的阅读学习,两个项目关键代码的分析,也让我对NLP和深度学习两个领域有了具体的探索,为我接下来的学习提供了参考方向。

每个读着代码找逻辑的日日夜夜,每次看着公式找思路的冥思苦想,每次找遍全网死磕模型的轴里轴气,塑造了现在的值得。

感谢每一篇认真分享的博主们,感谢老师学长和优秀的家人们,本期博客完结撒花!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值