训练GNN用来做Graph Classification
一、原理
1、根据Message Passing得到每个节点的node embedding
2、readout layer
把所有节点的node embedding聚合成整个图的graph embedding。
【文献中有很多种不同的readout layer,但最常用的是mean】
【跟Node Classification的区别】:是否把每个节点的node embedding聚合成一个graph embedding?
针对mini-batch,PyG框架有封装好的模块,torch_geometric.nn.global_mean_pool 可以分别将mini-batch中每个图的所有node embedding聚合成一个graph embedding(一个batch中有多少个图,就有多少个graph embedding)。一个batch的graph embedding矩阵的shape为:[batch_size,hidden_channels]。hidden_channels:一个graph embedding(向量)的长度
3、训练一个针对graph embedding的分类器
二、代码实现
PyG框架是什么?如何安装?可以参照官方文档or我的上一篇博客:https://blog.csdn.net/qq_38432089/article/details/122152640?spm=1001.2014.3001.5501
1、数据集准备
import torch
from torch_geometric.datasets import TUDataset
# 1、数据集下载
dataset = TUDataset(root='data/TUDataset', name='MUTAG')
# 查看数据集信息
print()
print(f'Dataset: {
dataset}:')
print('====================')
print(f'Number of graphs: {
len(dataset)}')
print(f'Number of features: {
dataset.num_features}')
print(f'Number of classes: {
dataset.num_classes}')
data = dataset[0] # Get the first graph object.
print()
print(data)
print('=============================================================')
# Gather some statistics about the first graph.
print(