GNN Study Notes
Course notes: GNN from Beginner to Expert
2.1 DeepWalk (Code)
DeepWalk: Online Learning of Social Representations (KDD '14)
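The core idea of DeepWalk is to make several passes over the nodes in shuffled order. Each pass starts one truncated random walk from every node, and each walk is treated as a "sentence". A Skip-Gram model is then trained on the resulting corpus, so nodes that co-occur within a context window end up with similar vectors. The code below is a minimal stdlib-only sketch of the corpus-building step. It assumes uniform neighbor sampling, as in the paper. The helper names (`random_walk`, `build_corpus`, `skipgram_pairs`) and the toy graph are hypothetical, and the Skip-Gram training itself is not reproduced. `walk_number`, `walk_length` and `window_size` are meant to mirror the karateclub parameters of the same names.

```python
import random
from typing import Dict, List, Tuple

# Graph as an adjacency list: node -> list of neighbors (hypothetical toy format).
Graph = Dict[int, List[int]]


def random_walk(graph: Graph, start: int, walk_length: int, rng: random.Random) -> List[int]:
    """Truncated uniform random walk with at most walk_length nodes."""
    walk = [start]
    while len(walk) < walk_length:
        neighbors = graph[walk[-1]]
        if not neighbors:  # dead end: stop this walk early
            break
        walk.append(rng.choice(neighbors))  # pick the next node uniformly
    return walk


def build_corpus(graph: Graph, walk_number: int, walk_length: int, seed: int = 0) -> List[List[int]]:
    """walk_number passes; each pass shuffles the nodes and starts one walk per node."""
    rng = random.Random(seed)
    nodes = list(graph)
    corpus = []
    for _ in range(walk_number):
        rng.shuffle(nodes)
        for node in nodes:
            corpus.append(random_walk(graph, node, walk_length, rng))
    return corpus


def skipgram_pairs(walk: List[int], window_size: int) -> List[Tuple[int, int]]:
    """(center, context) pairs within window_size positions, as Skip-Gram sees them."""
    pairs = []
    for i, center in enumerate(walk):
        lo, hi = max(0, i - window_size), min(len(walk), i + window_size + 1)
        for j in range(lo, hi):
            if j != i:
                pairs.append((center, walk[j]))
    return pairs


if __name__ == "__main__":
    toy = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
    corpus = build_corpus(toy, walk_number=2, walk_length=5)
    print(len(corpus))                            # 2 passes x 4 nodes = 8 walks
    print(skipgram_pairs([0, 1, 2], window_size=1))
    # [(0, 1), (1, 0), (1, 2), (2, 1)]
```

A larger `window_size` makes nodes that are several hops apart on the same walk count as context for each other. A longer `walk_length` gives each walk more such pairs.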
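Use DeepWalk to build node embeddings for Zachary's Karate Club network, then use scikit-learn's logistic regression (LR) model to classify each node's club membership.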
import networkx as nx
from karateclub import DeepWalk
# Load Karate Club Graph
G = nx.karate_club_graph()
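# Generate DeepWalk Embeddings
# walk_length: nodes per random walk; dimensions: embedding size;
# window_size: Skip-Gram context window; epochs: training epochs; workers: number of cores
model = DeepWalk(walk_length=5, dimensions=128, window_size=5, epochs=20, workers=1)
model.fit(G)
# NumPy array of shape (number of nodes, dimensions); row i is the embedding of node i
embeddings = model.get_embedding()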
# Generate X and Y for training
X = []
Y = []
for node in G.nodes():
    X.append(embeddings[node])         # feature: the node's embedding vector
    Y.append(G.nodes[node]["club"])    # label: "Mr. Hi" or "Officer"
# Split into training and test sets (sklearn default: 25% of the nodes are held out for testing)
from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(X, Y)
# Train a Logistic Regression Classifier
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression(random_state=0).fit(X_train, Y_train)
# Evaluate the Classifier
accuracy = clf.score(X_test, Y_test)
print("Accuracy:", accuracy)
Output
Accuracy: 1.0
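The accuracy above comes from a very small test set. The sketch assumes scikit-learn's default `test_size=0.25` and that the test share is rounded up. Under those assumptions, only 9 of the 34 nodes are evaluated, so each misclassified node costs about 11 percentage points. The stdlib-only code below works through that arithmetic. The helper names `split_sizes` and `accuracy` are hypothetical, not sklearn APIs, and the labels are made-up predictions, not results from the run above.

```python
import math
from typing import Sequence, Tuple


def split_sizes(n_samples: int, test_size: float = 0.25) -> Tuple[int, int]:
    """Return (n_train, n_test), rounding the test share up (assumed sklearn behavior)."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


def accuracy(y_true: Sequence[str], y_pred: Sequence[str]) -> float:
    """Fraction of exact label matches, the mean accuracy that classifier .score() reports."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need two non-empty sequences of equal length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


n_train, n_test = split_sizes(34)   # Zachary's Karate Club has 34 nodes
print(n_train, n_test)              # 25 9

# Hypothetical test labels: every prediction correct vs. exactly one wrong
y_true = ["Mr. Hi"] * 5 + ["Officer"] * 4
perfect = list(y_true)
one_wrong = ["Officer"] + y_true[1:]
print(accuracy(y_true, perfect))               # 1.0
print(round(accuracy(y_true, one_wrong), 3))   # 0.889
```

A single 1.0 therefore means "all 9 held-out nodes were classified correctly" and should not be read as a stable estimate. The score can move in steps of roughly 0.11 between different random splits.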