Blog10 图神经网络的模型级解释——关键代码分析2

2021SC@SDUSC

 

本篇分析代码模块为:model.py文件、test.py文件

在上一篇博客中,我们分析了加载数据集的文件,这篇文章,我们来对model.py和test.py文件进行分析。

一、引入包简介

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

pytorch提供了许多优雅的类和模块帮助我们构建与训练网络,如 torch.nntorch.optim,Dataset等。

  • torch.nn:该模块里面,除了大量的损失函数与激活函数,里面还包含了大量用于构建网络的函数。
  • torch.nn.Parameter():首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面,所以经过类型转换变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。

二、GraphConvolution(nn.Module)类

该类定义了一个简单的 GCN 层,类似于 https://arxiv.org/abs/1609.02907。 

相关论文请参考: Semi-Supervised Classification with Graph Convolutional Networks

class GraphConvolution(nn.Module):
    
    # 模型的参数包括weight和bias
    def __init__(self, in_features, out_features):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

super(GraphConvolution, self).__init__():这个比较拗口的语法点,其实意思很简单,首先找到GraphConvolution的父类(这里是类nn.Module),然后把类的对象self转换为类nn.Module的对象,然后“被转换”的类nn.Module对象调用自己的_init_函数。

这是对继承自父类的属性进行初始化。而且是用父类的初始化方法来初始化继承的属性。 

也就是说,子类继承了父类的所有属性和方法,父类属性自然会用父类方法来进行初始化。

    # 权重初始化
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        self.bias.data.uniform_(-stdv, stdv)

    # 类似于tostring
    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

    # 计算A~ X W(0)
    def forward(self, input, adj):
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        return output + self.bias
  • input.shape = [max_node, features] = X
  • adj.shape = [max_node, max_node] = A~
  • torch.mm(a, b)是矩阵a和b矩阵相乘
  • torch.mul(a, b)是矩阵a和b对应位相乘,a和b的维度必须相等

三、GCN类

在前面博客中,我们分析了论文的相关内容,已经了解了GNN的思想,那下面来学习一下GCN的思想。如图:

 

平均法求节点特征只是简单得将该节点的所有邻居节点的特征拿来做平均,但是我们考虑上图这种情况。A只有B一个邻居,而B却广泛交友,导致自己身上聚合了太多他人的特征,并且这些特征与A关系并不大,我们想象这样一个情景(纯属博主虚构,对小明毫无恶意):社畜小明在阿里上班,有幸与马云爸爸有那么一丝关联,而现在小明就是A,马云爸爸就是B的话,小明与马云爸爸的差距就不言而喻了,但凡只用平均法将马云爸爸经过周边一系列富商聚合的最终特征来给小明当作邻居特征,结果只能是小明的梦想,却不能成为现状。

而GCN就是要解决这个问题,我们分析一下,出现上面这种情况的原因,就是B连接的邻居节点太多了,并且与A都没有关系——我朋友的朋友不是我的朋友。这就涉及到B节点度的问题,我们把图中的邻接矩阵用A~表示,节点的度矩阵用D~表示,那么可以推导出上图右侧的公式。具体的,放到我们上面的例子中,就是Dii代表A的度,Djj代表B的度,将结果除以他们乘机的开方,就可以在一定程度上解决我们上述描述的问题。

有了上述思想,我们来看一下GCN的代码。

class GCN(nn.Module):
    # feature的个数;最终的分类数

    def __init__(self, nfeat, nclass, dropout):
        super(GCN, self).__init__()
        self.dropout = dropout
        self.gc1 = GraphConvolution(nfeat, 32)
        self.gc2 = GraphConvolution(32, 48)
        self.gc3 = GraphConvolution(48, 64)
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, nclass)
 
     def forward(self, x, adj):
        # x.shape = [max_node, features]
        # adj.shape = [max_node, max_node]
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu(self.gc2(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu(self.gc3(x, adj))

        y = torch.mean(x, 0)  # 采用mean作为聚合函数聚合所有结点的特征
        y = F.relu(self.fc1(y))
        y = F.dropout(y, self.dropout, training=self.training)
        y = F.softmax(self.fc2(y), dim=0)

        return y

F.dropout ( nn.functional.dropout ):Dropout是指在模型训练时随机让网络某些隐含层节点的权重不工作,不工作的那些节点可以暂时认为不是网络结构的一部分,但是它的权重得保留下来(只是暂时不更新而已),因为下次样本输入时它可能又得工作了。

使用的时候需要设置它的training这个状态参数与模型整体的一致.

F.relu():激活函数采用ReLu(Mutag)。

根据论文可知,3 层 GCN,输出维度分别等于 32, 48, 64 。并对所有节点特征求平均值。

最终分类器具有 2 个全连接层,并且隐藏维度设置为 32 。

最终,设置特征数为7,表示七个原子的类型,分类标签为2,代表对细菌有无致突变作用。得到输出output.size()如下: 

综上,model.py文件定义了一个GCN分类模型,该模型有三层GCN,输出维度分别等于 32, 48, 64  , 2 个全连接层。

四、模型测试

对上个代码文件定义的GCN模型进行效果评估。

if __name__ == '__main__':
    adj_list, features_list, graph_labels, idx_map, idx_train, idx_val, idx_test = load_split_MUTAG_data()
    model = GCN(nfeat=features_list[0].shape[1],  # nfeat = 7
                nclass=graph_labels.max().item() + 1,  # nclass = 2
                dropout=0.1)

    model.eval()
    outputs = []
    for i in idx_test:
        output = model(features_list[i], adj_list[i])
        output = output.unsqueeze(0)
        outputs.append(output)
    output = torch.cat(outputs, dim=0)

    loss_test = F.cross_entropy(output, graph_labels[idx_test])
    acc_test = accuracy(output, graph_labels[idx_test])
    print(loss_test)
    print(acc_test)

在测试集中,计算交叉熵损失和精度 并输出。打印结果如下:

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值