There are already plenty of introductions to Node2Vec, so the algorithm itself is not covered here. This article shows how to quickly implement Node2Vec node classification with PyTorch Geometric and visualize the results.

The whole process consists of four steps:

- Load the graph data (Cora is used here)
- Build the Node2Vec model
- Train the model and evaluate it
- Reduce the embeddings with t-SNE and visualize them

The complete code is as follows:
```python
import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import Node2Vec

# Load the Cora citation dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = Node2Vec(data.edge_index, embedding_dim=128, walk_length=20,
                 context_size=10, walks_per_node=10,
                 num_negative_samples=1, sparse=True).to(device)

# The loader yields batches of positive and negative random walks
loader = model.loader(batch_size=128, shuffle=True, num_workers=4)
# SparseAdam is used because the embedding was created with sparse=True
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)


def train():
    model.train()
    total_loss = 0
    for pos_rw, neg_rw in loader:
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


@torch.no_grad()
def test():
    model.eval()
    z = model()  # embeddings of all nodes
    # Fit a logistic-regression classifier on the frozen embeddings
    acc = model.test(z[data.train_mask], data.y[data.train_mask],
                     z[data.test_mask], data.y[data.test_mask],
                     max_iter=150)
    return acc


for epoch in range(1, 101):
    loss = train()
    acc = test()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Acc: {acc:.4f}')


@torch.no_grad()
def plot_points(colors):
    model.eval()
    z = model(torch.arange(data.num_nodes, device=device))
    # Reduce the 128-dimensional embeddings to 2D with t-SNE
    z = TSNE(n_components=2).fit_transform(z.cpu().numpy())
    y = data.y.cpu().numpy()

    plt.figure(figsize=(8, 8))
    for i in range(dataset.num_classes):
        plt.scatter(z[y == i, 0], z[y == i, 1], s=20, color=colors[i])
    plt.axis('off')
    plt.show()


colors = ['#ffc0cb', '#bada55', '#008080', '#420420',
          '#7fe5f0', '#065535', '#ffd700']
plot_points(colors)
```
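For reference, the `model.test` call in the code essentially fits a scikit-learn logistic-regression classifier on the frozen embeddings and reports test accuracy. A minimal sketch of that evaluation step in isolation, using random stand-in embeddings and labels (the array shapes, seed, and 70/30 split are assumptions for illustration, not the real Cora masks):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Stand-in embeddings: 100 nodes, 16 dimensions, 3 hypothetical classes
z = rng.normal(size=(100, 16))
y = rng.integers(0, 3, size=100)

# A hypothetical 70/30 train/test split playing the role of the masks
train_mask = np.zeros(100, dtype=bool)
train_mask[:70] = True
test_mask = ~train_mask

# max_iter=150 mirrors the argument passed to model.test above
clf = LogisticRegression(max_iter=150)
clf.fit(z[train_mask], y[train_mask])
acc = clf.score(z[test_mask], y[test_mask])
print(f'accuracy: {acc:.4f}')
```

Since the embeddings are learned without labels, this downstream classifier is the only supervised part of the pipeline.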
The output is as follows:
```
Epoch: 01, Loss: 8.0489, Acc: 0.1780
Epoch: 02, Loss: 6.0371, Acc: 0.2000
Epoch: 03, Loss: 4.9573, Acc: 0.2190
Epoch: 04, Loss: 4.1341, Acc: 0.2540
Epoch: 05, Loss: 3.4723, Acc: 0.2920
Epoch: 06, Loss: 2.9627, Acc: 0.3250
Epoch: 07, Loss: 2.5404, Acc: 0.3550
Epoch: 08, Loss: 2.2190, Acc: 0.3840
Epoch: 09, Loss: 1.9515, Acc: 0.4200
Epoch: 10, Loss: 1.7386, Acc: 0.4470
Epoch: 11, Loss: 1.5671, Acc: 0.4670
Epoch: 12, Loss: 1.4274, Acc: 0.4890
Epoch: 13, Loss: 1.3167, Acc: 0.5060
Epoch: 14, Loss: 1.2284, Acc: 0.5340
Epoch: 15, Loss: 1.1606, Acc: 0.5530
Epoch: 16, Loss: 1.1018, Acc: 0.5580
Epoch: 17, Loss: 1.0573, Acc: 0.5850
Epoch: 18, Loss: 1.0241, Acc: 0.5970
Epoch: 19, Loss: 0.9943, Acc: 0.6070
Epoch: 20, Loss: 0.9701, Acc: 0.6080
Epoch: 21, Loss: 0.9523, Acc: 0.6230
Epoch: 22, Loss: 0.9342, Acc: 0.6290
Epoch: 23, Loss: 0.9200, Acc: 0.6360
Epoch: 24, Loss: 0.9104, Acc: 0.6460
Epoch: 25, Loss: 0.9003, Acc: 0.6350
Epoch: 26, Loss: 0.8936, Acc: 0.6440
Epoch: 27, Loss: 0.8859, Acc: 0.6590
Epoch: 28, Loss: 0.8776, Acc: 0.6570
Epoch: 29, Loss: 0.8730, Acc: 0.6590
Epoch: 30, Loss: 0.8696, Acc: 0.6660
Epoch: 31, Loss: 0.8658, Acc: 0.6740
Epoch: 32, Loss: 0.8609, Acc: 0.6740
Epoch: 33, Loss: 0.8586, Acc: 0.6750
Epoch: 34, Loss: 0.8559, Acc: 0.6770
Epoch: 35, Loss: 0.8524, Acc: 0.6800
Epoch: 36, Loss: 0.8516, Acc: 0.6880
Epoch: 37, Loss: 0.8494, Acc: 0.6820
Epoch: 38, Loss: 0.8481, Acc: 0.6770
Epoch: 39, Loss: 0.8464, Acc: 0.6800
Epoch: 40, Loss: 0.8435, Acc: 0.6850
Epoch: 41, Loss: 0.8423, Acc: 0.6890
Epoch: 42, Loss: 0.8401, Acc: 0.6870
Epoch: 43, Loss: 0.8392, Acc: 0.6900
Epoch: 44, Loss: 0.8390, Acc: 0.6850
Epoch: 45, Loss: 0.8365, Acc: 0.6880
Epoch: 46, Loss: 0.8363, Acc: 0.6920
Epoch: 47, Loss: 0.8354, Acc: 0.6990
Epoch: 48, Loss: 0.8345, Acc: 0.6970
Epoch: 49, Loss: 0.8352, Acc: 0.6970
Epoch: 50, Loss: 0.8350, Acc: 0.7040
Epoch: 51, Loss: 0.8333, Acc: 0.6970
Epoch: 52, Loss: 0.8320, Acc: 0.6980
Epoch: 53, Loss: 0.8321, Acc: 0.6960
Epoch: 54, Loss: 0.8325, Acc: 0.6940
Epoch: 55, Loss: 0.8312, Acc: 0.7010
Epoch: 56, Loss: 0.8298, Acc: 0.7040
Epoch: 57, Loss: 0.8294, Acc: 0.6990
Epoch: 58, Loss: 0.8296, Acc: 0.6960
Epoch: 59, Loss: 0.8302, Acc: 0.7050
Epoch: 60, Loss: 0.8286, Acc: 0.7030
Epoch: 61, Loss: 0.8298, Acc: 0.7020
Epoch: 62, Loss: 0.8292, Acc: 0.7010
Epoch: 63, Loss: 0.8288, Acc: 0.7090
Epoch: 64, Loss: 0.8284, Acc: 0.6990
Epoch: 65, Loss: 0.8267, Acc: 0.6970
Epoch: 66, Loss: 0.8274, Acc: 0.6950
Epoch: 67, Loss: 0.8279, Acc: 0.6940
Epoch: 68, Loss: 0.8274, Acc: 0.6940
Epoch: 69, Loss: 0.8278, Acc: 0.7000
Epoch: 70, Loss: 0.8258, Acc: 0.7000
Epoch: 71, Loss: 0.8283, Acc: 0.6990
Epoch: 72, Loss: 0.8257, Acc: 0.6990
Epoch: 73, Loss: 0.8262, Acc: 0.7060
Epoch: 74, Loss: 0.8260, Acc: 0.7120
Epoch: 75, Loss: 0.8259, Acc: 0.7140
Epoch: 76, Loss: 0.8266, Acc: 0.7060
Epoch: 77, Loss: 0.8254, Acc: 0.7070
Epoch: 78, Loss: 0.8261, Acc: 0.7030
Epoch: 79, Loss: 0.8258, Acc: 0.6980
Epoch: 80, Loss: 0.8253, Acc: 0.6950
Epoch: 81, Loss: 0.8256, Acc: 0.7050
Epoch: 82, Loss: 0.8252, Acc: 0.7070
Epoch: 83, Loss: 0.8238, Acc: 0.7060
Epoch: 84, Loss: 0.8253, Acc: 0.7060
Epoch: 85, Loss: 0.8251, Acc: 0.7070
Epoch: 86, Loss: 0.8255, Acc: 0.7090
Epoch: 87, Loss: 0.8251, Acc: 0.7160
Epoch: 88, Loss: 0.8247, Acc: 0.7140
Epoch: 89, Loss: 0.8246, Acc: 0.7020
Epoch: 90, Loss: 0.8245, Acc: 0.7050
Epoch: 91, Loss: 0.8250, Acc: 0.7160
Epoch: 92, Loss: 0.8249, Acc: 0.7100
Epoch: 93, Loss: 0.8249, Acc: 0.7040
Epoch: 94, Loss: 0.8245, Acc: 0.7060
Epoch: 95, Loss: 0.8252, Acc: 0.7030
Epoch: 96, Loss: 0.8244, Acc: 0.6990
Epoch: 97, Loss: 0.8242, Acc: 0.7030
Epoch: 98, Loss: 0.8244, Acc: 0.7050
Epoch: 99, Loss: 0.8236, Acc: 0.6990
Epoch: 100, Loss: 0.8233, Acc: 0.7030
```
The visualization looks like this:

The Cora dataset contains seven node classes, which is why seven colors are used. Since the classification accuracy is only around 0.7, the visualization is not particularly clean either: some classes are mixed together. Other methods could be tried to improve on this.