GNN学习笔记(四):图注意力神经网络(GAT)节点分类任务实现

目录

0 引言

1、Cora数据集

2、citeseer数据集

3、Pubmed数据集

4、DBLP数据集

5、Tox21 数据集

6、代码


嘚嘚嘚,唠叨小主,闪亮登场,哈哈,过时了过时了,闪亮登场换成大驾光临,哈哈,这样才颇有气势,哼哼...

(唠叨小主哼哼了两声,对新改的词表示满意,微拉裙侧,留下了高跟鞋的声音...)

0 引言

近年来,人们对深度学习方法在图上的扩展越来越感兴趣。在多方因素的成功推动下,研究人员借鉴了卷积网络、循环网络和深度自动编码器的思想,定义和设计了用于处理图数据的神经网络结构,由此一个新的研究热点——“图神经网络(Graph Neural Networks,GNN)”应运而生。

图神经网络的研究与图嵌入或网络嵌入密切相关,图嵌入或网络嵌入是数据挖掘和机器学习界日益关注的另一个课题。许多图嵌入算法通常是无监督的算法,它们可以大致可以划分为三个类别,即矩阵分解、随机游走和深度学习方法。同时图嵌入的深度学习方法也属于图神经网络,包括基于图自动编码器的算法(如DNGR和SDNE)和无监督训练的图卷积神经网络(如GraphSage)。我们将图神经网络划分为五大类别,分别是:图卷积网络(Graph Convolution Networks,GCN)、 图注意力网络(Graph Attention Networks)、图自编码器( Graph Autoencoders)、图生成网络( Graph Generative Networks) 和图时空网络(Graph Spatial-temporal Networks)。


今天的任务——参照GNN学习笔记中的代码使用PyG中的图卷积模块PyG的数据集上实现节点分类或回归任务,之前用到的是MLP、图卷积神经网络、图注意力神经网络,数据集是Cora数据集。

1、Cora数据集

Cora数据集由机器学习论文组成,是近年来图深度学习很喜欢使用的数据集。在数据集中,论文分为以下七类之一:基于案例、遗传算法、神经网络、概率方法、强化学习、规则学习、理论。论文的选择方式是,在最终语料库中,每篇论文引用或被至少一篇其他论文引用。整个语料库中有2708篇论文。在词干堵塞和去除词尾后,只剩下1433个独特的单词。文档频率小于10的所有单词都被删除。cora数据集包含1433个独特单词,所以特征是1433维。0和1描述的是每个单词在paper中是否存在。

数据下载地址:https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz

2、citeseer数据集

数据集下载地址:http://www.cs.umd.edu/~sen/lbc-proj/data/citeseer.tgz 

3、Pubmed数据集

PubMed数据集包括来自Pubmed数据库的19717篇关于糖尿病的科学出版物,分为三类:

Diabetes Mellitus, Experimental
Diabetes Mellitus Type 1
Diabetes Mellitus Type 2
引文网络由44338个链接组成。数据集中的每个出版物都由一个由500个唯一单词组成的字典中的TF/IDF加权词向量来描述。
数据集下载地址:https://linqs-data.soe.ucsc.edu/public/Pubmed-Diabetes.tgz

4、DBLP数据集

DBLP数据集用XML描述,字段信息包括:author、title、pages、year、booktitle、url、crossref、publisher、ee、cdrom、isbn、cite_label等。其中作者名属性信息的格式是统一的,处理比较方便。目前,DBLP对作者重名问题的处理已经有不错的效果。例如:输入一作者名“wei wang”,可以得到16个不同的作者及其工作单位,并能链接得到每个作者的发表论文情况、个人主页和合作者列表等信息。(不存在问题了吗?)此外,引文信息中除了基本信息:作者名、文章名、会议名之外,加入新的信息:author keywords,对应于论文中的keywords。但是,并非所有的论文都包含有author keywords信息,也并非所有作者都有个人主页,在个人主页链接识别上还存在问题。 

5、Tox21 数据集

此数据集来源于一个PubChem网站的一个2014年的竞赛:https://tripod.nih.gov/tox21/challenge/about.jsp
PubChem是美国国立卫生研究院(NIH)的开放化学数据库,是世界上最大的免费化学物信息集合。
PubChem的数据由数百个数据源提供,包括:政府机构,化学品供应商,期刊出版商等。

数据集可在此下载:https://tripod.nih.gov/tox21/challenge/data.jsp#

训练集和测试集都是由多个分子结构构成的sdf格式的文件。


6、代码

我们今天基于Citeseer数据构建图注意力神经网络模型。

# 获取并分析数据集
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GATConv
 
dataset = Planetoid(root='data/Planetoid', name='citeseer', transform=NormalizeFeatures())

###神经网络的构造
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)
        # self.lin1 = Linear(dataset.num_features, hidden_channels)
        # self.lin2 = Linear(hidden_channels, dataset.num_classes)


    def forward(self, x, edge_index):
        # x = self.lin1(x)
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        #x = self.lin2(x)
        return x
 
model = GCN(hidden_channels=16)
print(model)


###模型的训练
model = GCN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) ##定义Adam优化器
criterion = torch.nn.CrossEntropyLoss() ##交叉熵损失

def train():
      model.train()
      optimizer.zero_grad()  # Clear gradients.
      out = model(data.x, data.edge_index)  # Perform a single forward pass.
      loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
      loss.backward()  # Derive gradients.
      optimizer.step()  # Update parameters based on gradients.
      return loss

for epoch in range(1, 201):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

##模型的测试 这部分与MLP神经网络相同
def test():
      model.eval()
      out = model(data.x, data.edge_index)
      pred = out.argmax(dim=1)  # Use the class with highest probability.
      test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
      test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
      return test_acc

test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')

##可视化
model.eval()

out = model(data.x, data.edge_index)
visualize(out, color=data.y)




参考资料:
【1】https://zhuanlan.zhihu.com/p/75307407?from_voters_page=true

【2】https://blog.csdn.net/yyl424525/article/details/100831452

【3】https://blog.csdn.net/yyl424525/article/details/100831452

【4】https://blog.csdn.net/qq_32797059/article/details/106577815

  • 3
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值