本文参考自datawhale2021.6学习:图神经网络
【GNN】第三章 消息传递范式与PyG的MessagePassing基类
【GNN】第四章 节点表征学习与节点分类任务(理论+调包实操)
【GNN】第五章 构造数据完全存于内存的数据集类InMemoryDataset
本章目录
1 边预测任务
目标是预测两个节点间是否存在边
1.1 训练集、验证集、测试集的构建
需求:
- 正负样本平滑:
data.edge_index
存储了正样本,但为了构建预测任务还需要负样本(不存在边的节点对),同时正负样本数量要平衡 - 分训练、验证、测试集
解决: torch_geometric.utils.train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1)
- 返回六个属性取代
edge_index
:train_pos_edge_index 、train_neg_adj_mask、
val_pos_edge_index、val_neg_edge_index、test_pos_edge_index
和test_neg_edge_index - train_neg_adj_mask没用
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges
dataset = Planetoid('dataset','Cora',transform=T.NormalizeFeatures())
data = dataset[0]
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)
for key in data.keys:
print(key, getattr(data,key).shape)
"""
class A(object):
bar = 1
a = A()
getattr(a, 'bar')
# 1
getattr(a, 'bar2', 3)
# bar2不存在,但设置了默认值3
"""
Cora是无向图:
- 统计边数量时正反方向各统计一次
- 训练集也包含了正反方向,但验证集与测试集只包含一个方向
- 理由:训练集要使网络学习出节点间信息流的传递,只考虑一个方向就会使信息缺失;而验证集和测试集只做网络能力的检验作用,只考虑一个方向即可
1.2 边预测神经网络(以GCN为例)
1.2.1 边预测的模型构造
三部分:
- 编码(encode):生成节点表征
- 解码(decode):根据两端点节点表征,计算有边的概率,计算方法在下述代码中是通过头尾端点属性对应相乘后求和
- 推理(decode_all):计算所有节点彼此有边的概率,方法同解码。torch.nonzero()
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
class Net(nn.Module):
def __init__(self, in_channels, out_channels):
super(Net,self).__init__()
self.conv1 = GCNConv(in_channels, 128)
self.conv2 = GCNConv(128, out_channels)
def encode(self, x, edge_index): # 节点表征学习
x = self.conv1(x,edge_index)
x = x.relu()
x = self.conv2(x,edge_index)
return x
def decode(self, z, pos_edge_index, neg_edge_index)