文章目录
各种机器学习应用中的许多数据集在其实体之间具有结构关系,可以表示为图。 比如社交和通信网络分析、流量预测和欺诈检测等。 图表示学习旨在为用于各种 ML 任务的图数据集构建和训练模型。
该
example
演示了
图神经网络 (GNN)模型的简单实现。 该模型在
Cora 数据集上进行节点预测任务,以根据其单词和引文网络预测论文主题。
我们从头开始实现图卷积层,以更好地理解它们的工作原理。 但是,有许多基于 TensorFlow 的专门库提供了丰富的 GNN API,例如 Spectral、StellarGraph 和 GraphNets。
Setup
import os
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
准备数据集
使用Cora dataset
,该数据集包括2708科学文章,并且分类为七类。citation network 有5429个links(链接)。每篇论文有一个大小为1433的二进制词向量,表示相对应的词。
该数据集有两个文件:cora.cites和cora.content
。
cora.cites
包括两列(cited_paper_id(target)
和citing_paper_id(source)
cora.content
包括1435列的paper content records:paper_id,subject
和1433二进制特征
下载数据集:
zip_file = keras.utils.get_file(
fname="cora.tgz",
origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
extract=True,
)
data_dir = os.path.join(os.path.dirname(zip_file), "cora")
运行如下:
处理和可视化数据集
加载citations data
到Pandas DataFrame中。
citations = pd.read_csv(
os.path.join(data_dir, "cora.cites"),
sep="\t",
header=None,
names=["target", "source"],
)
# print("Citations shape:", citations.shape) #(5429,2)
Display a sample of the citations
DataFrame. The target
列包括paper ids cited by the paper ids 在source
列。
citations.sample(frac=1).head()
然后加载paper data
到Pandas DataFrame。
column_names = ["paper_id"] + [f"term_{
idx}" for idx in range(1433)] + ["subject"]
papers = pd.read_csv(
os.path.join(data_dir, "cora.content"), sep="\t", header=None, names=column_names,
)
print("Papers shape:", papers.shape) # (2708, 1435)
Display a sample of papers
DataFrame. 该数据框包括paper_id
和subject
列以及1433二进制列表示在paper中是否存在一个term.
print(papers.sample(5).T)
然后,展示每一个subject中paper的数量
print(papers.subject.value_counts())
然后,转换paper ids和subjects到zero-based indices
.
class_values = sorted(papers["subject"].unique())
class_idx = {
name: id for id, name in enumerate(class_values)}
paper_idx = {
name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}
papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])
接着,可视化citation graph
,图中的每个节点代表一篇paper,节点的颜色对应它的subject
,下面展示的是数据集当中的一个sample。
plt.figure(figsize=(10, 10))
colors = papers["subject"].tolist()
cora_graph = nx.from_pandas_edgelist(citations.sample(n=1500))
subjects = list(papers[papers["paper_id"].isin(list(cora_graph.nodes))]["subject"])
nx.draw_spring(cora_graph, node_size=15, node_color=subjects)
拆分数据集为分层训练集和测试集
train_data, test_data = [], []
for _, group_data in papers.groupby("subject"):
# Select around 50% of the dataset for training.
random_selection = np.random.rand(len(group_data.index)) <= 0.5
train_data.append(group_data[random_selection])
test_data.append(group_data[~random_selection])
train_data = pd.concat(train_data).sample(frac=1)
test_data =