【pyg】第一篇总结(基于karate的3层GCN+简单可视化,额外补充了cora)

这篇博客通过Karate空手道俱乐部和Cora数据集,详细介绍了如何构建和可视化3层GCN模型。在未训练前,随机初始化的节点嵌入已显现出社区结构。经过训练,模型成功将节点分类,并利用PyTorch几何库简化了实现过程。
摘要由CSDN通过智能技术生成

目录

Karate空手道俱乐部

数据集dataset统计查看

单张图graph数据data统计查看

可视化数据单张图数据

GCN3层模型定义

嵌入表示 可视化

半监督节点分类训练输出+可视化

cora

数据集dataset统计查看

单张图graph数据data统计查看

可视化数据单张图数据

GCN模型定义

嵌入表示 可视化

训练+可视化输出

参考


Karate空手道俱乐部

数据集dataset统计查看

代码

from torch_geometric.datasets import KarateClub

# 数据集查看
dataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

结果

Dataset: KarateClub():
======================
Number of graphs: 1
Number of features: 34
Number of classes: 4

单张图graph数据data统计查看

代码

# 数据+属性信息查看
data = dataset[0]  # Get the first graph object.


print(data)
结果
Data(edge_index=[2, 156], train_mask=[34], x=[34, 34], y=[34])
【我们可以看到,这个数据对象包含4个属性】
1)edge_index属性包含关于图连通性的信息,即每个边的源节点和目标节点索引的元组。
2)节点特征表示为x(34个节点每个节点分配一个34-dim的特征向量),
3)节点标签表示为y(每个节点都被精确分配到一个类)。
4)还有一个额外的属性train_mask,它描述了我们已经知道哪些节点的社区分配。

总的来说,我们只知道4个节点的ground-truth标签(每个节点一个)

其他属性查看代码

print('==============================================================')

# 数据对象还提供一些实用函数来推断底层图的一些基本属性
# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

结果

==============================================================
Number of nodes: 34
Number of edges: 156
Average node degree: 4.59
Number of training nodes: 4
Training node label rate: 0.12
Contains isolated nodes: False
Contains self-loops: False
Is undirected: True

信息查看图链接情况代码

通过打印edge_index,我们可以进一步理解PyG如何在内部表示图的连通性。
我们可以看到,对于每条边,edge_index保存一个由两个节点索引组成的元组,其中第一个值描述源节点的节点索引,第二个值描述边的目标节点的节点索引。
这种表示被称为COO格式(坐标格式),通常用于表示稀疏矩阵。

PyG不是用密集表示A来保存邻接信息,而是稀疏表示图,即只保存A中的项非零的坐标/值。
我们可以通过将其转换为networkx库格式来进一步可视化图
# 详细查看邻接矩阵情况
from IPython.display import Javascript,display  # Restrict height of output cell.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

edge_index = data.edge_index
print(edge_index.t())

结果

<IPython.core.display.Javascript object>
tensor([[ 0,  1],
        [ 0,  2],
        [ 0,  3],
        [ 0,  4],
        [ 0,  5],
        [ 0,  6],
        [ 0,  7],
        [ 0,  8],
        [ 0, 10],
        [ 0, 11],
        [ 0, 12],
        [ 0, 13],
        [ 0, 17],
        [ 0, 19],
        [ 0, 21],
        [ 0, 31],
        [ 1,  0],
        [ 1,  2],
        [ 1,  3],
        [ 1,  7],
        [ 1, 13],
        [ 1, 17],
        [ 1, 19],
        [ 1, 21],
        [ 1, 30],
        [ 2,  0],
        [ 2,  1],
        [ 2,  3],
        [ 2,  7],
        [ 2,  8],
        [ 2,  9],
        [ 2, 13],
        [ 2, 27],
        [ 2, 28],
        [ 2, 32],
        [ 3,  0],
        [ 3,  1],
        [ 3,  2],
        [ 3,  7],
        [ 3, 12],
        [ 3, 13],
        [ 4,  0],
        [ 4,  6],
        [ 4, 10],
        [ 5,  0],
        [ 5,  6],
        [ 5, 10],
        [ 5, 16],
        [ 6,  0],
        [ 6,  4],
        [ 6,  5],
        [ 6, 16],
        [ 7,  0],
        [ 7,  1],
        [ 7,  2],
        [ 7,  3],
        [ 8,  0],
        [ 8,  2],
        [ 8, 30],
        [ 8, 32],
        [ 8, 33],
        [ 9,  2],
        [ 9, 33],
        [10,  0],
        [10,  4],
        [10,  5],
        [11,  0],
        [12,  0],
        [12,  3],
        [13,  0],
        [13,  1],
        [13,  2],
        [13,  3],
        [13, 33],
        [14, 32],
        [14, 33],
        [15, 32],
        [15, 33],
        [16,  5],
        [16,  6],
        [17,  0],
        [17,  1],
        [18, 32],
        [18, 33],
        [19,  0],
        [19,  1],
        [19, 33],
        [20, 32],
        [20, 33],
        [21,  0],
        [21,  1],
        [22, 32],
        [22, 33],
        [23, 25],
        [23, 27],
        [23, 29],
        [23, 32],
        [23, 33],
        [24, 25],
        [24, 27],
        [24, 31],
        [25, 23],
        [25, 24],
        [25, 31],
        [26, 29],
        [26, 33],
        [27,  2],
        [27, 23],
        [27, 24],
        [27, 33],
        [28,  2],
        [28, 31],
        [28, 33],
        [29, 23],
        [29, 26],
        [29, 32],
        [29, 33],
        [30,  1],
        [30,  8],
        [30, 32],
        [30, 33],
        [31,  0],
        [31, 24],
        [31, 25],
        [31, 28],
        [31, 32],
        [31, 33],
        [32,  2],
        [32,  8],
        [32, 14],
        [32, 15],
        [32, 18],
        [32, 20],
        [32, 22],
        [32, 23],
        [32, 29],
        [32, 30],
        [32, 31],
        [32, 33],
        [33,  8],
        [33,  9],
        [33, 13],
        [33, 14],
        [33, 15],
        [33, 18],
        [33, 19],
        [33, 20],
        [33, 22],
        [33, 23],
        [33, 26],
        [33, 27],
        [33, 28],
        [33, 29],
        [33, 30],
        [33, 31],
        [33, 32]])

可视化数据单张图数据

代码

from torch_geometric.utils import to_networkx
import torch
import networkx as nx
import matplotlib.pyplot as plt


def visualize(h, color, epoch=None, loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])

    if torch.is_tensor(h):
        h = h.detach().cpu().numpy()
        plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
        if epoch is not None and loss is not None:
            plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    else:
        nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                         node_color=color, cmap="Set2")
    plt.show()


G = to_networkx(data, to_undirected=True)
visualize(G, color=data.y)

结果

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

静静喜欢大白

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值