节点的预测任务
首先定义图神经网络的网络结构,这里使用了torch_geometric.nn.Sequential
容器,详细内容可见于官方文档。我们通过hidden_channels_list
参数来设置每一层GATConv的out_channels,通过修改hidden_channels_list,我们就可以构造出不同的图神经网络。这里通过三种方式改变网络结构:
- 1、使用不同的卷积层
- 2、使用不同的层数
- 3、每层的不同的神经元个数
1、通过不同的卷积层
这里使用了GCNConv、GATConv、SAGEConv、GraphConv和TransformerConv五种卷积核,运行效果如下:
# Define the data first: load the Cora citation dataset with
# row-normalized node features.
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
dataset = Planetoid(root='dataset', name='Cora', transform=NormalizeFeatures())
# Dataset-level statistics.
print()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset[0]  # Get the first graph object.
print()
print(data)
print('======================')
# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
# NOTE(review): contains_isolated_nodes()/contains_self_loops() were renamed to
# has_isolated_nodes()/has_self_loops() in PyG >= 2.0 — confirm installed version.
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
####### GCN architecture for node classification
class GCN(torch.nn.Module):
    """Stack of GCNConv layers followed by a Linear classification head.

    Each entry of ``hidden_channels_list`` is the output width of one
    GCNConv layer (ReLU after every conv); the final Linear layer maps
    the last hidden width to ``num_classes``.
    """

    def __init__(self, num_features, hidden_channels_list, num_classes):
        super(GCN, self).__init__()
        torch.manual_seed(12345)  # reproducible weight initialization
        widths = [num_features] + hidden_channels_list
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append((GCNConv(fan_in, fan_out), 'x, edge_index -> x'))
            modules.append(ReLU(inplace=True))
        self.convseq = Sequential('x, edge_index', modules)
        self.linear = Linear(hidden_channels_list[-1], num_classes)

    def forward(self, x, edge_index):
        x = self.convseq(x, edge_index)
        x = F.dropout(x, p=0.5, training=self.training)
        return self.linear(x)
其他网络层结构与上面类似,最终实验结果如图所示
2、使用不同层数
- hidden_channels_list=[2]
- hidden_channels_list=[100]
- hidden_channels_list=[200, 100]
- hidden_channels_list=[200, 100, 50]
- hidden_channels_list=[500, 200, 100]
- hidden_channels_list=[500, 200, 100, 50]
# Candidate depth/width configurations for hidden_channels_list
# (one hidden layer up to four hidden layers).
layers = [[2],
          [100],
          [200, 100],
          [200, 100, 50],
          [500, 200, 100],
          [500, 200, 100, 50]]
def test_case(layer):
    # Build a GCN whose hidden widths are given by `layer` and report test accuracy.
    model = GCN(num_features=dataset.num_features, hidden_channels_list=layer, num_classes=dataset.num_classes).to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    # NOTE(review): `criterion` is unused here — presumably the global train()
    # reads it (and `model`/`optimizer`) from the enclosing scope; confirm.
    criterion = torch.nn.CrossEntropyLoss()
    # NOTE(review): train()/test() are assumed to be defined elsewhere in the
    # notebook for the node-classification task — confirm.
    for epoch in range(1, 301):
        loss = train()
    test_acc = test()
    print(f'Test Accuracy: {test_acc:.4f}')
for layer in layers:
    test_case(layer)
3、每层使用不同的神经元
- hidden_channels_list=[2]
- hidden_channels_list=[20]
- hidden_channels_list=[200]
# Candidate widths for a single hidden layer.
layers = [[2],
          [20],
          [200]]
def test_case(layer):
    # Build a GCN whose hidden widths are given by `layer` and report test accuracy.
    model = GCN(num_features=dataset.num_features, hidden_channels_list=layer, num_classes=dataset.num_classes).to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    # NOTE(review): `criterion` is unused here — presumably the global train()
    # reads it (and `model`/`optimizer`) from the enclosing scope; confirm.
    criterion = torch.nn.CrossEntropyLoss()
    # NOTE(review): train()/test() are assumed to be defined elsewhere in the
    # notebook for the node-classification task — confirm.
    for epoch in range(1, 101):
        loss = train()
    test_acc = test()
    print(f'Test Accuracy: {test_acc:.4f}')
for layer in layers:
    test_case(layer)
边的预测任务
边预测任务,目标是预测两个节点之间是否存在边。拿到一个图数据集,我们有节点属性x
,边端点edge_index
。edge_index
存储的便是正样本。为了构建边预测任务,我们需要生成一些负样本,即采样一些不存在边的节点对作为负样本边,正负样本数量应平衡。此外要将样本分为训练集、验证集和测试集三个集合。
PyG中为我们提供了现成的采样负样本边的方法,train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1)
,其
- 第一个参数为 torch_geometric.data.Data 对象,
- 第二个参数为验证集所占比例,
- 第三个参数为测试集所占比例。
import os.path as osp

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from sklearn.metrics import roc_auc_score
from torch.nn import Linear, ReLU
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, Sequential
from torch_geometric.utils import negative_sampling, train_test_split_edges
###### 用Sequential实现(以下为注释掉的 GAT 版本)
# class GAT(torch.nn.Module):
# def __init__(self, num_features, hidden_channels_list, num_classes):
# super(GAT, self).__init__()
# torch.manual_seed(12345)
# hns = [num_features] + hidden_channels_list
# conv_list = []
# for idx in range(len(hidden_channels_list)):
# conv_list.append((GATConv(hns[idx], hns[idx+1]), 'x, edge_index -> x'))
# conv_list.append(ReLU(inplace=True),)
# self.convseq = Sequential('x, edge_index', conv_list)
# self.linear = Linear(hidden_channels_list[-1], num_classes)
# def forward(self, x, edge_index):
# x = self.convseq(x, edge_index)
# x = F.dropout(x, p=0.5, training=self.training)
# x = self.linear(x)
# return x
class Net(torch.nn.Module):
    """GCN encoder plus dot-product decoder for link prediction.

    ``encode`` maps node features to embeddings through a stack of
    GCNConv+ReLU layers and a Linear head; ``decode`` scores candidate
    edges by the inner product of their endpoint embeddings.
    """

    def __init__(self, in_channels, hidden_channels_list, out_channels):
        super(Net, self).__init__()
        widths = [in_channels] + hidden_channels_list
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append((GCNConv(fan_in, fan_out), 'x, edge_index -> x'))
            modules.append(ReLU(inplace=True))
        self.convseq = Sequential('x, edge_index', modules)
        self.linear = Linear(hidden_channels_list[-1], out_channels)

    def encode(self, x, edge_index):
        """Return node embeddings for all nodes."""
        h = self.convseq(x, edge_index)
        h = F.dropout(h, p=0.5, training=self.training)
        return self.linear(h)

    def decode(self, z, pos_edge_index, neg_edge_index):
        """Return one logit per candidate edge (positives first, then negatives)."""
        edges = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        return (z[edges[0]] * z[edges[1]]).sum(dim=-1)

    def decode_all(self, z):
        """Return the edge_index of every node pair whose dot-product score is positive."""
        scores = z @ z.t()
        return (scores > 0).nonzero(as_tuple=False).t()
def get_link_labels(pos_edge_index, neg_edge_index):
    """Build the 0/1 target vector for a batch of candidate edges.

    The first ``pos_edge_index.size(1)`` entries are 1 (real edges), the
    remaining ``neg_edge_index.size(1)`` entries are 0 (sampled non-edges),
    matching the ordering produced by ``Net.decode``.

    Fix: allocate the labels on the same device as ``pos_edge_index``.
    The original always allocated on CPU, forcing every caller to move the
    tensor; callers that still call ``.to(device)`` are unaffected.
    """
    num_links = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_labels = torch.zeros(num_links, dtype=torch.float,
                              device=pos_edge_index.device)
    link_labels[:pos_edge_index.size(1)] = 1.
    return link_labels
def train(data, model, optimizer):
    """Run one optimization step of the link-prediction model and return the loss."""
    model.train()

    # Draw as many negative edges as there are training positives.
    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=data.train_pos_edge_index.size(1))

    # Sanity check: sampled training negatives must not coincide with
    # validation/test positive edges.
    sampled_neg = set(map(tuple, neg_edge_index.T.tolist()))
    val_pos = set(map(tuple, data.val_pos_edge_index.T.tolist()))
    test_pos = set(map(tuple, data.test_pos_edge_index.T.tolist()))
    if (sampled_neg & val_pos) or (sampled_neg & test_pos):
        print('wrong!')

    optimizer.zero_grad()
    embeddings = model.encode(data.x, data.train_pos_edge_index)
    logits = model.decode(embeddings, data.train_pos_edge_index, neg_edge_index)
    labels = get_link_labels(data.train_pos_edge_index, neg_edge_index).to(data.x.device)
    loss = F.binary_cross_entropy_with_logits(logits, labels)
    loss.backward()
    optimizer.step()
    return loss
@torch.no_grad()
def test(data, model):
    """Return [val_auc, test_auc]: ROC-AUC of the decoder on both edge splits."""
    model.eval()
    z = model.encode(data.x, data.train_pos_edge_index)
    aucs = []
    for split in ['val', 'test']:
        pos_edge_index = data[f'{split}_pos_edge_index']
        neg_edge_index = data[f'{split}_neg_edge_index']
        probs = model.decode(z, pos_edge_index, neg_edge_index).sigmoid()
        labels = get_link_labels(pos_edge_index, neg_edge_index)
        aucs.append(roc_auc_score(labels.cpu(), probs.cpu()))
    return aucs
def main():
    """Edge-prediction pipeline: load Cora, split edges, train, decode all links."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = 'Cora'
    path = './dataset'
    dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
    data = dataset[0]
    # dataset = 'Cora'
    # path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
    # dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
    # data = dataset[0]
    # NOTE(review): ground_truth_edge_index is computed but never used below.
    ground_truth_edge_index = data.edge_index.to(device)
    # Clear node-task attributes, then split edges into train/val/test
    # (with sampled negatives for val/test).
    data.train_mask = data.val_mask = data.test_mask = data.y = None
    data = train_test_split_edges(data)
    data = data.to(device)
    model = Net(dataset.num_features, [256, 128], 64).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
    best_val_auc = test_auc = 0
    for epoch in range(1, 3001):
        loss = train(data, model, optimizer)
        val_auc, tmp_test_auc = test(data, model)
        # Report the test AUC at the epoch with the best validation AUC.
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            test_auc = tmp_test_auc
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
              f'Test: {test_auc:.4f}')
    z = model.encode(data.x, data.train_pos_edge_index)
    # Predict an edge for every node pair with a positive score.
    final_edge_index = model.decode_all(z)
if __name__ == "__main__":
main()
运行结果(这里出现wrong的原因是训练集负样本与验证集正样本存在交集,或训练集负样本与测试集正样本存在交集):
本文内容主要来自datawhale开源课程