PyTorch Graph Neural Network Practice (4): Node2Vec Node Classification and Visualization

There are already plenty of introductions to Node2Vec, so this post will not repeat them. Instead, it shows how to implement Node2Vec node classification quickly with PyTorch Geometric and how to visualize the results.

The whole process consists of four steps:

  • Load the graph data (Cora is used here)
  • Build the Node2Vec model
  • Train the model and evaluate it
  • Reduce the embeddings to 2D with t-SNE and visualize them
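Before the PyG version, it helps to see what Node2Vec actually samples: a second-order biased random walk, controlled by a return parameter p and an in-out parameter q. Here is a minimal pure-Python sketch of one such walk on an unweighted graph (the function name and adjacency-dict format are my own for illustration, not PyG's API):

```python
import random

def node2vec_walk(adj, start, walk_length, p=1.0, q=1.0):
    """One biased second-order random walk on an unweighted graph.

    adj: dict mapping node -> list of neighbors.
    p: return parameter (larger p -> less likely to step back).
    q: in-out parameter (q < 1 favors outward, DFS-like moves).
    """
    walk = [start]
    while len(walk) < walk_length:
        cur = walk[-1]
        nbrs = adj.get(cur, [])
        if not nbrs:
            break  # dead end: stop early
        if len(walk) == 1:
            # first step has no previous node, so it is uniform
            walk.append(random.choice(nbrs))
            continue
        prev = walk[-2]
        weights = []
        for x in nbrs:
            if x == prev:                   # stepping back to prev
                weights.append(1.0 / p)
            elif x in adj.get(prev, []):    # distance 1 from prev
                weights.append(1.0)
            else:                           # moving outward
                weights.append(1.0 / q)
        walk.append(random.choices(nbrs, weights=weights)[0])
    return walk
```

With p = q = 1 this reduces to a plain uniform random walk (DeepWalk-style), which matches the defaults used in the PyG code below.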

The complete code is as follows:

import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import Node2Vec

# Cora: 2,708 papers in 7 classes, connected by citation edges
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 128-dim embeddings from 10 walks of length 20 per node, with a
# skip-gram context window of 10 and 1 negative sample per positive
model = Node2Vec(data.edge_index, embedding_dim=128, walk_length=20,
                 context_size=10, walks_per_node=10, num_negative_samples=1,
                 sparse=True).to(device)
loader = model.loader(batch_size=128, shuffle=True, num_workers=4)
# sparse=True above requires an optimizer that supports sparse gradients
optimizer = torch.optim.SparseAdam(model.parameters(), lr=0.01)


def train():
    model.train()
    total_loss = 0
    for pos_rw, neg_rw in loader:  # batches of positive and negative walks
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


@torch.no_grad()
def test():
    model.eval()
    z = model()  # embeddings for all nodes
    # fits a logistic-regression classifier on the training-node embeddings
    # and reports accuracy on the test-node embeddings
    acc = model.test(z[data.train_mask], data.y[data.train_mask],
                     z[data.test_mask], data.y[data.test_mask], max_iter=150)
    return acc


for epoch in range(1, 101):
    loss = train()
    acc = test()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Acc: {acc:.4f}')


@torch.no_grad()
def plot_points(colors):
    model.eval()
    z = model(torch.arange(data.num_nodes, device=device))
    z = TSNE(n_components=2).fit_transform(z.cpu().numpy())  # 128-dim -> 2-dim
    y = data.y.cpu().numpy()

    plt.figure(figsize=(8, 8))
    for i in range(dataset.num_classes):  # one color per class
        plt.scatter(z[y == i, 0], z[y == i, 1], s=20, color=colors[i])
    plt.axis('off')
    plt.show()


colors = ['#ffc0cb', '#bada55', '#008080', '#420420', '#7fe5f0', '#065535', '#ffd700']
plot_points(colors)
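For reference, the (pos_rw, neg_rw) batches that loader yields in the training loop above are built from walks: each positive sample is a (target, context) pair drawn from a window of context_size consecutive nodes in a walk, while negative walks are sampled at random. The positive-pair construction can be sketched like this (a simplified stand-in for illustration, not PyG's internal implementation):

```python
def walk_to_context_pairs(walk, context_size):
    """Slide a window of length context_size over a walk and emit
    (target, context) pairs, mirroring how skip-gram positives are built."""
    pairs = []
    for i in range(len(walk) - context_size + 1):
        window = walk[i:i + context_size]
        target = window[0]
        for ctx in window[1:]:
            pairs.append((target, ctx))
    return pairs
```

So for a walk [1, 2, 3, 4] with context_size=3, the positives are (1,2), (1,3), (2,3), (2,4); the loss then pushes each target's embedding toward its contexts and away from negative samples.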

The output is as follows:

Epoch: 01, Loss: 8.0489, Acc: 0.1780
Epoch: 02, Loss: 6.0371, Acc: 0.2000
Epoch: 03, Loss: 4.9573, Acc: 0.2190
Epoch: 04, Loss: 4.1341, Acc: 0.2540
Epoch: 05, Loss: 3.4723, Acc: 0.2920
Epoch: 06, Loss: 2.9627, Acc: 0.3250
Epoch: 07, Loss: 2.5404, Acc: 0.3550
Epoch: 08, Loss: 2.2190, Acc: 0.3840
Epoch: 09, Loss: 1.9515, Acc: 0.4200
Epoch: 10, Loss: 1.7386, Acc: 0.4470
Epoch: 11, Loss: 1.5671, Acc: 0.4670
Epoch: 12, Loss: 1.4274, Acc: 0.4890
Epoch: 13, Loss: 1.3167, Acc: 0.5060
Epoch: 14, Loss: 1.2284, Acc: 0.5340
Epoch: 15, Loss: 1.1606, Acc: 0.5530
Epoch: 16, Loss: 1.1018, Acc: 0.5580
Epoch: 17, Loss: 1.0573, Acc: 0.5850
Epoch: 18, Loss: 1.0241, Acc: 0.5970
Epoch: 19, Loss: 0.9943, Acc: 0.6070
Epoch: 20, Loss: 0.9701, Acc: 0.6080
Epoch: 21, Loss: 0.9523, Acc: 0.6230
Epoch: 22, Loss: 0.9342, Acc: 0.6290
Epoch: 23, Loss: 0.9200, Acc: 0.6360
Epoch: 24, Loss: 0.9104, Acc: 0.6460
Epoch: 25, Loss: 0.9003, Acc: 0.6350
Epoch: 26, Loss: 0.8936, Acc: 0.6440
Epoch: 27, Loss: 0.8859, Acc: 0.6590
Epoch: 28, Loss: 0.8776, Acc: 0.6570
Epoch: 29, Loss: 0.8730, Acc: 0.6590
Epoch: 30, Loss: 0.8696, Acc: 0.6660
Epoch: 31, Loss: 0.8658, Acc: 0.6740
Epoch: 32, Loss: 0.8609, Acc: 0.6740
Epoch: 33, Loss: 0.8586, Acc: 0.6750
Epoch: 34, Loss: 0.8559, Acc: 0.6770
Epoch: 35, Loss: 0.8524, Acc: 0.6800
Epoch: 36, Loss: 0.8516, Acc: 0.6880
Epoch: 37, Loss: 0.8494, Acc: 0.6820
Epoch: 38, Loss: 0.8481, Acc: 0.6770
Epoch: 39, Loss: 0.8464, Acc: 0.6800
Epoch: 40, Loss: 0.8435, Acc: 0.6850
Epoch: 41, Loss: 0.8423, Acc: 0.6890
Epoch: 42, Loss: 0.8401, Acc: 0.6870
Epoch: 43, Loss: 0.8392, Acc: 0.6900
Epoch: 44, Loss: 0.8390, Acc: 0.6850
Epoch: 45, Loss: 0.8365, Acc: 0.6880
Epoch: 46, Loss: 0.8363, Acc: 0.6920
Epoch: 47, Loss: 0.8354, Acc: 0.6990
Epoch: 48, Loss: 0.8345, Acc: 0.6970
Epoch: 49, Loss: 0.8352, Acc: 0.6970
Epoch: 50, Loss: 0.8350, Acc: 0.7040
Epoch: 51, Loss: 0.8333, Acc: 0.6970
Epoch: 52, Loss: 0.8320, Acc: 0.6980
Epoch: 53, Loss: 0.8321, Acc: 0.6960
Epoch: 54, Loss: 0.8325, Acc: 0.6940
Epoch: 55, Loss: 0.8312, Acc: 0.7010
Epoch: 56, Loss: 0.8298, Acc: 0.7040
Epoch: 57, Loss: 0.8294, Acc: 0.6990
Epoch: 58, Loss: 0.8296, Acc: 0.6960
Epoch: 59, Loss: 0.8302, Acc: 0.7050
Epoch: 60, Loss: 0.8286, Acc: 0.7030
Epoch: 61, Loss: 0.8298, Acc: 0.7020
Epoch: 62, Loss: 0.8292, Acc: 0.7010
Epoch: 63, Loss: 0.8288, Acc: 0.7090
Epoch: 64, Loss: 0.8284, Acc: 0.6990
Epoch: 65, Loss: 0.8267, Acc: 0.6970
Epoch: 66, Loss: 0.8274, Acc: 0.6950
Epoch: 67, Loss: 0.8279, Acc: 0.6940
Epoch: 68, Loss: 0.8274, Acc: 0.6940
Epoch: 69, Loss: 0.8278, Acc: 0.7000
Epoch: 70, Loss: 0.8258, Acc: 0.7000
Epoch: 71, Loss: 0.8283, Acc: 0.6990
Epoch: 72, Loss: 0.8257, Acc: 0.6990
Epoch: 73, Loss: 0.8262, Acc: 0.7060
Epoch: 74, Loss: 0.8260, Acc: 0.7120
Epoch: 75, Loss: 0.8259, Acc: 0.7140
Epoch: 76, Loss: 0.8266, Acc: 0.7060
Epoch: 77, Loss: 0.8254, Acc: 0.7070
Epoch: 78, Loss: 0.8261, Acc: 0.7030
Epoch: 79, Loss: 0.8258, Acc: 0.6980
Epoch: 80, Loss: 0.8253, Acc: 0.6950
Epoch: 81, Loss: 0.8256, Acc: 0.7050
Epoch: 82, Loss: 0.8252, Acc: 0.7070
Epoch: 83, Loss: 0.8238, Acc: 0.7060
Epoch: 84, Loss: 0.8253, Acc: 0.7060
Epoch: 85, Loss: 0.8251, Acc: 0.7070
Epoch: 86, Loss: 0.8255, Acc: 0.7090
Epoch: 87, Loss: 0.8251, Acc: 0.7160
Epoch: 88, Loss: 0.8247, Acc: 0.7140
Epoch: 89, Loss: 0.8246, Acc: 0.7020
Epoch: 90, Loss: 0.8245, Acc: 0.7050
Epoch: 91, Loss: 0.8250, Acc: 0.7160
Epoch: 92, Loss: 0.8249, Acc: 0.7100
Epoch: 93, Loss: 0.8249, Acc: 0.7040
Epoch: 94, Loss: 0.8245, Acc: 0.7060
Epoch: 95, Loss: 0.8252, Acc: 0.7030
Epoch: 96, Loss: 0.8244, Acc: 0.6990
Epoch: 97, Loss: 0.8242, Acc: 0.7030
Epoch: 98, Loss: 0.8244, Acc: 0.7050
Epoch: 99, Loss: 0.8236, Acc: 0.6990
Epoch: 100, Loss: 0.8233, Acc: 0.7030

The visualization looks like this:

[Figure: t-SNE visualization of the learned Cora embeddings, one color per class]

The Cora dataset has seven node classes, hence the seven colors. As the log shows, accuracy plateaus around 0.70, which is not especially high, so the visualization is correspondingly imperfect: some classes remain mixed together. Other methods could be considered to improve on this.
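For context on what that accuracy number means: as I understand it, model.test in the script above fits a scikit-learn logistic-regression classifier on the training-node embeddings and scores it on the test-node embeddings (max_iter=150 caps the solver iterations). The same evaluation protocol can be sketched on random stand-in embeddings (the two Gaussian blobs here are synthetic, just to make the protocol concrete):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
# stand-in "embeddings": two well-separated 8-dim Gaussian blobs, one per class
z_train = np.vstack([rng.normal(0, 1, (50, 8)), rng.normal(4, 1, (50, 8))])
y_train = np.array([0] * 50 + [1] * 50)
z_test = np.vstack([rng.normal(0, 1, (20, 8)), rng.normal(4, 1, (20, 8))])
y_test = np.array([0] * 20 + [1] * 20)

# fit on training embeddings, report accuracy on held-out embeddings
clf = LogisticRegression(max_iter=150).fit(z_train, y_train)
acc = clf.score(z_test, y_test)
```

The quality of this linear probe depends entirely on how separable the embedding space is, which is why the ~0.70 ceiling above reflects the embeddings themselves rather than the classifier.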

