【Graph Net】【专题系列】五、GraphSAGE代码实战
目录
一、简介
GraphSAGE(Graph SAmple and aggreGatE,2017)是一种高效的图嵌入方法,可用于节点分类、链接预测等多种图学习任务。与为每个节点单独学习一个嵌入向量的直推式(transductive)方法不同,GraphSAGE 是归纳式(inductive)的:它学习的是"如何从节点的局部邻域聚合出节点嵌入"的函数,因此能够适用于不同规模、不同结构的图,也能为训练时未见过的节点(或新图)生成嵌入。
算法流程大致可以分为以下步骤:
初始化:为图中每个节点赋予一个初始特征向量,这可以是节点的属性特征,或者简单的one-hot编码。
采样邻居:对于图中的每个节点,从其邻居中随机采样固定数目的节点。采样是为了减少计算资源的消耗,特别是对于那些有很多邻居的节点。采样数目可以是一个超参数。
聚合邻居信息:对于每个节点,根据它采样得到的邻居,使用一个聚合函数(例如均值、池化或LSTM等)来聚合邻居的特征表示。
更新节点表示:用聚合得到的邻居信息来更新目标节点的特征表示。更新时,可以将目标节点的原始特征和聚合后的邻居特征结合起来,例如,通过拼接后进行非线性变换。
重复迭代:把上述"采样邻居→聚合邻居信息→更新节点表示"三个步骤作为一轮(对应模型的一层),重复 K 轮(即堆叠 K 层)。第 k 轮结束后,每个节点的表示融合了其 k 跳邻域内的信息,因此轮数越多,节点表示所编码的邻域结构范围越大。
训练损失计算:对于有监督任务,在每次迭代(每个 batch,或全图训练时的每个 epoch)中根据节点表示计算任务损失(如交叉熵),并通过反向传播优化模型参数;对于无监督任务,原论文使用基于随机游走的损失——使在随机游走中共现的邻近节点嵌入相似,同时使负采样得到的节点嵌入不相似。
生成嵌入:训练完成后,使用最终模型对每个节点聚合邻居信息,以生成节点的嵌入表示。
使用节点嵌入:根据不同的任务目标(如节点分类、节点聚类、链路预测等),使用生成的节点嵌入进行下一步的任务处理
二、代码
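与前面【Graph Net学习】系列文章的代码结构相似,不清楚的地方可以参考前面的文章。P.S. 本文新增了 Embedding 的透传,可以直接拿到图节点的 Embedding 用于下游任务。需要注意:下面的代码是全图(full-batch)训练,每层 SAGEConv 直接在整个 edge_index 上聚合邻居信息,并没有进行第一节所说的邻居采样;在大图上可以结合 PyG 的 NeighborLoader 实现邻居采样的小批量训练。直接贴代码,配置好环境即可运行。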
import os
import time
import random
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GATConv
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data
from torch_geometric.transforms import NormalizeFeatures
import numpy as np
import pandas as pd
import scipy.sparse as sp
from sklearn.preprocessing import LabelEncoder
#配置项
class configs():
    """Hyper-parameters, paths and stage switches for the GraphSAGE-on-Cora script."""

    def __init__(self):
        # Data
        self.data_path = r'./data/cora'   # directory containing cora.content / cora.cites
        self.save_model_dir = './'        # train() writes latest.pth here
        self.model_name = r'GraphSAGE'    # was 'GAT' (copy-paste from the GAT article)
        self.seed = 2023
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Model / training
        self.epoch = 50
        # Cora: 1433 features per node. Informational only: train() takes the
        # input width from the loaded data (dataset.num_features).
        self.in_features = 1433
        self.hidden_features = 16         # width of the hidden GraphSAGE layer
        # NOTE(review): Cora is usually described as having 7 paper classes, so one of
        # these 8 logits is probably never a target; kept at 8 so the reported results
        # (parameter count, embedding shape) stay reproducible. Confirm with the label count.
        self.output_features = 8
        self.learning_rate = 0.01
        self.dropout = 0.5                # NOTE(review): not read by the GraphSAGE model
        # Which stages the __main__ block runs
        self.istrain = True
        self.istest = True
        self.isembedding = True


cfg = configs()
def seed_everything(seed=2024):
    """Seed Python's, NumPy's and PyTorch's random generators with ``seed``.

    PYTHONHASHSEED is also exported; note that it only influences interpreters
    started afterwards (e.g. subprocesses), not the hash seed of the running one.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
# Seed every RNG once at import time, using the seed from the config object.
seed_everything(seed = cfg.seed)
# 读取Cora数据集 return geometric Data格式
def index_to_mask(index, size):
    """Return a NumPy boolean vector of length ``size`` that is True exactly at ``index``."""
    selected = np.full(size, False, dtype=bool)
    selected[index] = True
    return selected
def load_cora_data(data_path = cfg.data_path):
    """Load the Cora citation dataset as a torch_geometric ``Data`` object.

    Reads ``cora.content`` (paper id, feature columns, class name) and
    ``cora.cites`` (one pair of paper ids per row) from ``data_path``.

    Returns:
        Data with ``x`` (float32 features), ``edge_index`` (2 x E, long; each
        cites row becomes one edge from column 0 to column 1, no reverse edges
        are added), ``y`` (integer-encoded labels) and boolean tensors
        ``train_mask`` / ``val_mask`` / ``test_mask`` covering rows 0-139 /
        200-499 / 500-1499 in content-file order.

    Raises:
        ValueError: if ``cora.cites`` mentions a paper id absent from ``cora.content``.
    """
    content_df = pd.read_csv(os.path.join(data_path, "cora.content"), delimiter="\t", header=None)
    content_df.set_index(0, inplace=True)
    # Build the id -> row-position map once; calling list.index() per edge was O(E * N).
    node_to_idx = {node_id: pos for pos, node_id in enumerate(content_df.index.tolist())}

    # Every column except the last holds features; the last one holds the class name.
    features = content_df.values[:, :-1].astype(np.float32)
    labels = LabelEncoder().fit_transform(content_df.values[:, -1])

    cites_df = pd.read_csv(os.path.join(data_path, "cora.cites"), delimiter="\t", header=None)
    edges = []
    for src, dst in zip(cites_df[0], cites_df[1]):
        try:
            edges.append((node_to_idx[int(src)], node_to_idx[int(dst)]))
        except KeyError as err:
            raise ValueError(f"cora.cites references unknown paper id {err.args[0]}") from err
    # reshape(-1, 2) keeps a (2, 0) layout even when there are no edges.
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()

    data = Data(x=torch.from_numpy(features),
                edge_index=edge_index,
                y=torch.from_numpy(labels))

    num_nodes = labels.shape[0]
    idx_train = range(140)
    idx_val = range(200, 500)
    idx_test = range(500, 1500)
    # Torch tensors (not NumPy arrays) so that data.to(device) moves the masks as well.
    data.train_mask = torch.from_numpy(index_to_mask(idx_train, size=num_nodes))
    data.val_mask = torch.from_numpy(index_to_mask(idx_val, size=num_nodes))
    data.test_mask = torch.from_numpy(index_to_mask(idx_test, size=num_nodes))
    return data
# GraphSAGE模型定义
class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE (``SAGEConv``) node classifier.

    Args:
        in_channels: width of the input node features.
        out_channels: width of the output (number of logits).
        hidden_channels: width of the hidden layer; ``None`` (default) uses
            ``cfg.hidden_features`` as before.
        aggr: neighbour aggregation passed to both ``SAGEConv`` layers
            (default ``'mean'``, the previous hard-coded value).
    """

    def __init__(self, in_channels, out_channels, hidden_channels=None, aggr='mean'):
        super(GraphSAGE, self).__init__()
        self.hidden_channels = cfg.hidden_features if hidden_channels is None else hidden_channels
        self.conv1 = SAGEConv(in_channels, self.hidden_channels, aggr=aggr)   # first GraphSAGE layer
        self.conv2 = SAGEConv(self.hidden_channels, out_channels, aggr=aggr)  # second GraphSAGE layer

    def forward(self, data):
        """Run both layers over the whole graph in ``data`` (no neighbour sampling here).

        Returns:
            ``(logits, embeddings)``: the second layer's output returned twice, once
            for the loss and once as the node embedding for downstream use.
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x, x
class myGraphSAGE_run():
    """Train / evaluate / embed driver for GraphSAGE on Cora, configured by ``cfg``."""

    def _model_path(self):
        """Path of the checkpoint written by train() and read by infer()/get_embedding()."""
        return os.path.join(cfg.save_model_dir, 'latest.pth')

    def _load_data(self):
        """Load Cora and move it, including the split masks, to ``cfg.device``."""
        data = load_cora_data()
        for name in ('train_mask', 'val_mask', 'test_mask'):
            # Masks may arrive as NumPy arrays; tensors are needed so .to(device) moves them.
            setattr(data, name, torch.as_tensor(getattr(data, name), dtype=torch.bool))
        return data.to(cfg.device)

    def _load_model(self):
        """Load the full pickled model saved by train() and switch it to eval mode.

        Unpickling can execute arbitrary code: only load checkpoints you produced yourself.
        """
        model_path = self._model_path()
        try:
            # torch >= 2.6 defaults to weights_only=True, which rejects a pickled nn.Module.
            model = torch.load(model_path, map_location=cfg.device, weights_only=False)
        except TypeError:
            # Older torch releases do not know the weights_only argument.
            model = torch.load(model_path, map_location=cfg.device)
        model.eval()
        return model

    @staticmethod
    def _accuracy(preds, labels, mask):
        """Fraction of masked nodes whose prediction equals the label."""
        correct = preds[mask].eq(labels[mask]).sum().item()
        return correct / int(mask.sum())

    def train(self):
        """Full-batch training for ``cfg.epoch`` epochs; saves the whole model to latest.pth."""
        t = time.time()
        data = self._load_data()
        model = GraphSAGE(data.num_features, cfg.output_features).to(cfg.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=5e-4)
        model.train()
        for epoch in range(cfg.epoch):
            optimizer.zero_grad()
            output, _ = model(data)
            loss_train = F.cross_entropy(output[data.train_mask], data.y[data.train_mask].long())
            loss_train.backward()
            optimizer.step()

            # Metrics use this epoch's pre-update output; detach so no autograd graph is built.
            logits = output.detach()
            preds = logits.argmax(dim=1)
            acc_train = self._accuracy(preds, data.y, data.train_mask)
            loss_val = F.cross_entropy(logits[data.val_mask], data.y[data.val_mask].long())
            acc_val = self._accuracy(preds, data.y, data.val_mask)
            print('Epoch: {:04d}'.format(epoch + 1),
                  'loss_train: {:.4f}'.format(loss_train.item()),
                  'acc_train: {:.4f}'.format(acc_train),
                  'loss_val: {:.4f}'.format(loss_val.item()),
                  'acc_val: {:.4f}'.format(acc_val),
                  'time: {:.4f}s'.format(time.time() - t))
        torch.save(model, self._model_path())  # save the whole model object

    def infer(self):
        """Evaluate the saved model on the test split and print loss, accuracy and size."""
        data = self._load_data()
        model = self._load_model()
        with torch.no_grad():
            output, _ = model(data)
        params = sum(p.numel() for p in model.parameters())
        preds = output.argmax(dim=1)
        loss_test = F.cross_entropy(output[data.test_mask], data.y[data.test_mask].long())
        acc_test = self._accuracy(preds, data.y, data.test_mask)
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test),
              'params={:.4f}k'.format(params/1024))

    def get_embedding(self):
        """Compute node embeddings with the saved model.

        Returns:
            NumPy array of shape (num_nodes, embedding width) for downstream use.
        """
        data = self._load_data()
        model = self._load_model()
        with torch.no_grad():  # inference only, no gradients needed
            _, GraphSAGE_embeddings = model(data)
        # .cpu() first: np.array() cannot convert a CUDA tensor directly.
        GraphSAGE_embeddings = GraphSAGE_embeddings.cpu().numpy()
        print("GraphSAGE_embeddings :", GraphSAGE_embeddings.shape)
        return GraphSAGE_embeddings
if __name__ == '__main__':
    mygraph = myGraphSAGE_run()
    # Each stage is switched on or off by a boolean flag in cfg.
    if cfg.istrain:
        mygraph.train()
    if cfg.istest:
        mygraph.infer()
    if cfg.isembedding:
        mygraph.get_embedding()
三、实验结果及分析
Epoch: 0001 loss_train: 2.1139 acc_train: 0.2143 loss_val: 2.1312 acc_val: 0.1567 time: 0.6940s
Epoch: 0002 loss_train: 1.8440 acc_train: 0.4929 loss_val: 1.9777 acc_val: 0.2633 time: 0.7290s
Epoch: 0003 loss_train: 1.5150 acc_train: 0.6500 loss_val: 1.7834 acc_val: 0.4400 time: 0.7630s
Epoch: 0004 loss_train: 1.1952 acc_train: 0.8071 loss_val: 1.5940 acc_val: 0.5200 time: 0.7990s
Epoch: 0005 loss_train: 0.9246 acc_train: 0.8857 loss_val: 1.4305 acc_val: 0.5933 time: 0.8350s
...
Epoch: 0019 loss_train: 0.0265 acc_train: 0.9929 loss_val: 0.9076 acc_val: 0.7467 time: 1.5176s
Epoch: 0020 loss_train: 0.0237 acc_train: 0.9929 loss_val: 0.9118 acc_val: 0.7533 time: 1.7016s
Epoch: 0021 loss_train: 0.0214 acc_train: 1.0000 loss_val: 0.9152 acc_val: 0.7567 time: 1.7806s
Epoch: 0022 loss_train: 0.0196 acc_train: 1.0000 loss_val: 0.9177 acc_val: 0.7567 time: 1.8576s
Epoch: 0023 loss_train: 0.0181 acc_train: 1.0000 loss_val: 0.9196 acc_val: 0.7567 time: 2.0186s
...
Epoch: 0046 loss_train: 0.0071 acc_train: 1.0000 loss_val: 0.9245 acc_val: 0.7500 time: 2.9452s
Epoch: 0047 loss_train: 0.0069 acc_train: 1.0000 loss_val: 0.9252 acc_val: 0.7467 time: 2.9851s
Epoch: 0048 loss_train: 0.0067 acc_train: 1.0000 loss_val: 0.9257 acc_val: 0.7433 time: 3.0253s
Epoch: 0049 loss_train: 0.0066 acc_train: 1.0000 loss_val: 0.9261 acc_val: 0.7433 time: 3.0633s
Epoch: 0050 loss_train: 0.0064 acc_train: 1.0000 loss_val: 0.9262 acc_val: 0.7400 time: 3.1123s
Test set results: loss= 0.9726 accuracy= 0.7010 params=45.0547k
GraphSAGE_embeddings : (2708, 8)
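训练准确率在第 21 个 epoch 达到 1.0,验证准确率在第 21–23 个 epoch 达到峰值 0.7567,之后略有下降(第 50 个 epoch 为 0.7400),同时验证损失从第 19 个 epoch 起缓慢上升。这说明模型在很小的训练集(140 个节点)上出现了轻微过拟合,可以考虑早停,或在模型中加入 dropout(配置中的 dropout=0.5 目前并未被模型使用)。测试集准确率为 0.701,参数量为 46136 个(代码中按 1024 换算,即 45.0547k)。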
四、展望
GraphSAGE的这种采样与聚合策略使得模型能够有效处理大规模图数据,并可以泛化到未见过的节点上,因为节点的嵌入表示不依赖于整个图的结构,而只依赖于它的局部邻域。此外,该方法还允许在不同层使用不同的聚合函数来捕获更多样化的邻域信息。