使用 PyTorch Geometric 和 GCTConv实现异构图、二部图上的节点分类或者链路预测

解决问题描述

使用 PyTorch Geometric 和 Heterogeneous Graph Transformer 实现异构图上的节点分类
在二部图上应用GTN算法(使用torch_geometric的库HGTConv);

步骤解释

  1. 导入所需的 PyTorch 和 PyTorch Geometric 库。

  2. 定义 x1 和 x2 两种不同类型节点的特征,分别有 1000 个和 500 个节点,每个节点有两维特征。
    随机生成两种边 e1 和 e2 的索引(edge index)和权重(edge weight),其中 e1 从 n1 到 n2,e2 从 n2 到 n1。

  3. 定义异构图的元数据字典 meta_dict,其中 ‘n1’ 和 ‘n2’ 分别表示两种节点类型,而 (‘n1’, ‘e1’, ‘n2’) 表示从类型 ‘n1’ 的节点到类型 ‘n2’ 的节点有一条边,这条边的索引和权重分别为 edge_index_e1 和 edge_weight_e1。

  4. 利用元数据字典 meta_dict 创建异构图数据对象 data,并将节点特征和边索引添加到该对象中。

  5. 定义异构元数据列表 meta_list,其中包含所有节点类型和边类型的名称信息。

  6. 定义 HGTConv 层,并指定输入通道数、输出通道数、异构元数据列表以及头数等超参数。

  7. 将节点特征和边索引转换为字典形式,并利用 HGTConv

  8. 应用 HGTConv 到输入数据,得到输出结果 output_dict,其中包含了处理后的节点特征。最后打印输出 n1 和 n2 节点的输出形状。

详细代码

以下代码可以直接运行

import torch
from torch_geometric.data import Data, HeteroData
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import HGTConv

# 定义节点特征
x1 = torch.randn(1000, 2)
x2 = torch.randn(500, 2)

# 定义边索引(edge index)以及边权重(edge weight)
edge_index_e1 = torch.cat((torch.randint(0, 1000, size=(1, 4000)),torch.randint(0, 500, size=(1, 4000))),dim=0)
edge_weight_e1 = torch.rand(4000)
edge_index_e2=torch.flip(edge_index_e1, (0,))

# 定义元数据字典,描述异构图的结构
meta_dict = {
    'n1': {'num_nodes': x1.shape[0], 'num_features': x1.shape[1]},
    'n2': {'num_nodes': x2.shape[0], 'num_features': x2.shape[1]},
    ('n1', 'e1', 'n2'): {'edge_index': edge_index_e1, 'edge_weight': edge_weight_e1},
}

# 创建异构图数据对象
data = HeteroData(meta_dict)

# 将节点特征和边索引添加到异构图对象中
data['n1'].x = x1
data['n2'].x = x2
data[('n1', 'e1', 'n2')].edge_index = edge_index_e1
data[('n2', 'e1', 'n1')].edge_index = edge_index_e2

# 定义异构元数据列表
meta_list= (['n1', 'n2'], [('n1', 'e1', 'n2'), ('n2', 'e1', 'n1')])

# 定义 HGTConv 层
in_channels = {
    'n1': x1.shape[1],
    'n2': x2.shape[1],
}
out_channels = 16
heads = 4
conv = HGTConv(in_channels=in_channels, out_channels=out_channels, metadata=meta_list,heads=heads)

# 将输入数据转换为字典形式
x_dict = {ntype: data[ntype].x for ntype in data.node_types}
edge_index_dict = {}
for etype in data.edge_types:
    edge_index_dict[etype] = data[etype].edge_index

# 应用 HGTConv 到输入数据
output_dict = conv(x_dict, edge_index_dict)
print(output_dict['n1'].shape)
print(output_dict['n2'].shape)

之后如果是节点分类则:

output_dict的n1,n2特征编码分别接全连接层对应y1,y2

之后如果是链路预测则:

output_dict的n1,n2特征编码按照链路进行合并,进而预测

一些细节

data = HeteroData(meta_dict) 创建异构图对象
edge_index_e2=torch.flip(edge_index_e1, (0,)) 创建逆向的边,由于是二部图无向图所以需要

  • 4
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
Pytorch Geometric提供了一些内置的归一化和池化方法,可以方便地用于GCN模型中。下面分别介绍这些方法的用法。 1. 归一化 Pytorch Geometric提供了两种常见的归一化方法:对称归一化和随机游走归一化。 对称归一化: ```python import torch_geometric.transforms as T data = T.NormalizeSymm()(data) ``` 随机游走归一化: ```python import torch_geometric.transforms as T data = T.RandomWalk()(data) ``` 其中,`data`是一个包含数据的对象,比如`torch_geometric.data.Data`。 2. 池化 池化操作可以将一张大缩小到一张小,从而减少模型参数和计算量。Pytorch Geometric提供了几种常见的池化方法,比如TopK池化、SAG Pooling和Diff Pooling。 TopK池化: ```python import torch_geometric.nn.pool as pool x, edge_index, batch = pool.topk(x, ratio=0.5, batch=batch) ``` 其中,`x`是节点特征矩阵,`edge_index`是边的索引矩阵,`batch`是节点所属的的标识符。`ratio`是池化后每个保留的节点数占原节点数的比例。 SAG Pooling: ```python import torch_geometric.nn.pool as pool x, edge_index, _, batch, _, _ = pool.sag_pool(x, edge_index, batch) ``` 其中,`x`、`edge_index`和`batch`的含义同TopK池化。SAG Pooling使用节点嵌入向量计算每个节点的注意力权重,根据权重进行池化。 Diff Pooling: ```python import torch_geometric.nn as nn diffpool = nn.DiffPool(in_channels, hidden_channels, num_classes) x, edge_index, edge_attr, batch, perm, score = diffpool(x, edge_index) ``` 其中,`in_channels`是输入节点特征的维度,`hidden_channels`是池化后节点特征的维度,`num_classes`是分类的类别数。`x`、`edge_index`和`batch`的含义同TopK池化。Diff Pooling使用GraphSAGE卷积层计算每个节点的嵌入向量,根据嵌入向量进行池化。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

_刘文凯_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值