PyG框架:Graph Classification

训练GNN用来做Graph Classification

一、原理

1、根据Message Passing得到每个节点的node embedding

2、readout layer
把所有节点的node embedding聚合成整个图的graph embedding。
【文献中有很多种不同的readout layer,但最常用的是mean】
在这里插入图片描述

【跟Node Classification的区别】:是否把每个节点的node embedding聚合成一个graph embedding?

针对mini-batch,PyG框架有封装好的模块,torch_geometric.nn.global_mean_pool 可以分别将mini-batch中每个图的所有node embedding聚合成一个graph embedding(一个batch中有多少个图,就有多少个graph embedding)。一个batch的graph embedding矩阵的shape为:[batch_size,hidden_channels]。hidden_channels:一个graph embedding(向量)的长度

3、训练一个针对graph embedding的分类器

二、代码实现

PyG框架是什么?如何安装?可以参照官方文档or我的上一篇博客:https://blog.csdn.net/qq_38432089/article/details/122152640?spm=1001.2014.3001.5501
1、数据集准备

import torch
from torch_geometric.datasets import TUDataset

# 1、数据集下载
dataset = TUDataset(root='data/TUDataset', name='MUTAG')

# 查看数据集信息
print()
print(f'Dataset: {
     dataset}:')
print('====================')
print(f'Number of graphs: {
     len(dataset)}')
print(f'Number of features: {
     dataset.num_features}')
print(f'Number of classes: {
     dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(
  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 14
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值