图卷积神经网络 (GCN)
采用了斯坦福CS224W的图神经课程相关ppt以及同济子豪兄的精讲视频,感兴趣的话可以直接去b站观看详细视频
1. 定义
图卷积神经网络是一种用于处理图结构数据的神经网络。在GCN中,图结构数据指的是由节点node(如人、地点、物体)和边edge(表示节点间的关系或连接)组成的数据。
GCN的工作原理是通过卷积操作来聚合和转换节点及其邻居节点的信息。这样做可以捕获图中的局部结构特征,类似于传统卷积神经网络(CNN)在图像处理中捕获像素的局部特征。GCN特别适合于那些节点表示和节点间关系都很重要的问题,如社交网络分析、分子结构识别、推荐系统等。
2. 计算图
2.1 Aggregate Neighbors
因为图数据的结构没有顺序和锚点等结构,所以我们不能直接把图的邻接矩阵以及图的数据直接输入到神经网络中。这里是通过消息传递的框架,来构造局部领域的计算图(聚类领域)
这里首先通过反向传播通过A节点来构造他的领域,一阶领域是第一层邻居的节点,二阶领域是邻居的邻居的节点来构造节点之间的关系。图中的层数是图的深度而不是神经网络的层数,而图中的盒子是内部的神经网络结构(自己去定义内部的神经网络),之后再通过前向传播来预测节点A,图中所示的是一个节点A的计算图。
每一个节点可以构造属于自己的计算图,每一个计算图就是一个样本(sample),并且相同层级的神经网络共享同一套参数,他们的权重是相同的
层数也多,覆盖的邻居也越多,根据“六度空间”理论,理论上图的深度到六层,就可以获取所有图的信息,我们可以堆叠多个图层来聚合更多更远的值,但有一个问题:如果我们添加了太多图层,聚合就会变得如此强烈,以至于所有的嵌入最终看起来都一样。这种现象被称为 “过度平滑”(over-smoothing),如果层数过多,就会产生真正的问题。一般来说我们选择的图卷积层数为两层或者三层来训练节点。
2.2 计算形式
我们可以看到图卷积层的计算的公式为
H
(
l
+
1
)
=
σ
(
D
~
−
1
2
A
~
D
~
−
1
2
H
(
l
)
W
(
l
)
)
H^{(l+1)} = \sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)}\right)
H(l+1)=σ(D~−21A~D~−21H(l)W(l))
我们逐步解析推出该公式,首先根据我们前面的图示,我们可以推出如下数学公式:
- 首先第0层的嵌入就是每个节点的属性特征,所以v节点在第零层的嵌入就是每个节点的属性特征
- 而之后的每一层嵌入都与前一层的嵌入有关(可以看图)
我们这里引入一个聚合函数,把v节点的所有邻居节点(上一层)的嵌入求和之后再除于连接数之后乘以该神经网络层的权重 - 而最终的Z输出我们定义为最终的节点嵌入(embedding)
ps:求平均的过程是order invariant 是顺序不变的 不管节点编号怎么换都是这个操作
我们用矩阵来表示这个公式可以看到
除于每个节点的连接数可以是邻接矩阵A
左乘节点度矩阵D的-1次方 D-1
这里给出用numpy
求节点度对角矩阵的一些代码运算以及结果显示
我们可以发现 D-1A 只能保留横向的节点度分配即只有自己的度没有考虑对方的连接数,而反过来AD-1 只保留了纵向的节点度分配即只考虑了对方的度。那么我们既要考虑自己又要考虑对方,于是我们整合成一个新的矩阵:Naively Symmetric Normalized Matrix
但是我们再求出矩阵之后计算矩阵的特征值发现最大的特征值达不到1,也就是说乘以之后他的幅值会相对之前变小,这不是我们期望的,于是就引入公式里面的:
D
−
1
2
A
~
D
~
−
1
2
{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}}
D−21A~D~−21
这时候我们发现矩阵的最大特征值有1:
也就是说我们进行矩阵相乘进行拉伸变换之后新的矩阵仍旧保持原来的方向,此时的向量为矩阵的特征向量,拉伸值为特征值。
最终的公式如图:
这里添加了一个自环,因为我除了考虑周围邻居的节点,我还想考虑自己的特征包括进我自己的影响,因此还可以添加一个自环增加一个自己作为新的邻居并附带权重
而自己的权重可以相同也可以不同,如果相同的话就是残差连接(类似于RNN),不同的话就是非恒等映射
我们可以对公式进行变换,可以得出如下新的公式:
3.代码实例
PyTorch Geometric 提供了直接实现图卷积层的 GCNConv
函数
本文采用的数据集为著名的空手道俱乐部KarateClub
数据集
!pip install torch_geometric
import torch
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.datasets import KarateClub
# Import dataset from PyTorch Geometric
dataset = KarateClub()
# Print information
print(dataset)
print('------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
KarateClub()
------------
Number of graphs: 1
Number of features: 34
Number of classes: 4
from torch.nn import Linear
from torch_geometric.nn import GCNConv
#import the packages
class GCN(torch.nn.Module):
def __init__(self):
"""
初始构造函数
创建了一个图卷积层(GCNConv):dataset.num_features
指定了输入特征的数量(即每个节点的特征维数),
而3指定了图卷积层输出的特征维数。
创建了一个线性层(Linear),用于从图卷积层的输出得到最终的类别预测。
输入特征维数为3(与图卷积层的输出维数相匹配)
输出维数为dataset.num_classes
"""
super().__init__()
self.gcn = GCNConv(dataset.num_features, 3)
self.out = Linear(3, dataset.num_classes)
def forward(self, x, edge_index):
"""
定义前向传播函数
x是节点的特征矩阵,edge_index是描述图中边的张量
"""
h = self.gcn(x, edge_index).relu()
z = self.out(h)
return h, z
model = GCN()
print(model)
GCN(
(gcn): GCNConv(34, 3)
(out): Linear(in_features=3, out_features=4, bias=True)
)
现在我们已经定义了 GNN,让我们用 PyTorch 编写一个简单的训练循环。由于这是一项多类分类任务,我们选择常规的交叉熵损失,并使用 Adam 作为优化器。为了保持简单,我们不会实现训练/测试分离,而是专注于 GNN 如何学习。
训练循环是标准的:我们尝试预测正确的标签,然后将 GCN 的结果与存储在 data.y 中的值进行比较。误差通过交叉熵损失计算,并通过 Adam 的反向传播来微调 GCN 的权重和偏置。最后,我们每 10 个epoch打印一次损失和准确率。
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
# Calculate accuracy
def accuracy(pred_y, y):
return (pred_y == y).sum() / len(y)
# Data for animations
embeddings = []
losses = []
accuracies = []
outputs = []
# Training loop
for epoch in range(201):
# Clear gradients
optimizer.zero_grad()
# Forward pass
h, z = model(data.x, data.edge_index)
# Calculate loss function
loss = criterion(z, data.y)
# Calculate accuracy
acc = accuracy(z.argmax(dim=1), data.y)
# Compute gradients
loss.backward()
# Tune parameters
optimizer.step()
# Store data for animations
embeddings.append(h)
losses.append(loss)
accuracies.append(acc)
outputs.append(z.argmax(dim=1))
# Print metrics every 10 epochs
if epoch % 10 == 0:
print(f'Epoch {epoch:>3} | Loss: {loss:.2f} | Acc: {acc*100:.2f}%')
Epoch 0 | Loss: 1.40 | Acc: 41.18%
Epoch 10 | Loss: 1.21 | Acc: 47.06%
Epoch 20 | Loss: 1.02 | Acc: 67.65%
Epoch 30 | Loss: 0.80 | Acc: 73.53%
Epoch 40 | Loss: 0.59 | Acc: 73.53%
Epoch 50 | Loss: 0.39 | Acc: 94.12%
Epoch 60 | Loss: 0.23 | Acc: 97.06%
Epoch 70 | Loss: 0.13 | Acc: 100.00%
Epoch 80 | Loss: 0.07 | Acc: 100.00%
Epoch 90 | Loss: 0.05 | Acc: 100.00%
Epoch 100 | Loss: 0.03 | Acc: 100.00%
Epoch 110 | Loss: 0.02 | Acc: 100.00%
Epoch 120 | Loss: 0.02 | Acc: 100.00%
Epoch 130 | Loss: 0.02 | Acc: 100.00%
Epoch 140 | Loss: 0.01 | Acc: 100.00%
Epoch 150 | Loss: 0.01 | Acc: 100.00%
Epoch 160 | Loss: 0.01 | Acc: 100.00%
Epoch 170 | Loss: 0.01 | Acc: 100.00%
Epoch 180 | Loss: 0.01 | Acc: 100.00%
Epoch 190 | Loss: 0.01 | Acc: 100.00%
Epoch 200 | Loss: 0.01 | Acc: 100.00%
我们可以通过图形动画制作一个简洁的可视化图表,查看 GNN 在训练过程中的预测变化
from IPython.display import HTML
from matplotlib import animation
plt.rcParams["animation.bitrate"] = 3000
def animate(i):
G = to_networkx(data, to_undirected=True)
nx.draw_networkx(G,
pos=nx.spring_layout(G, seed=0),
with_labels=True,
node_size=800,
node_color=outputs[i],
cmap="hsv",
vmin=-2,
vmax=3,
width=0.8,
edge_color="grey",
font_size=14
)
plt.title(f'Epoch {i} | Loss: {losses[i]:.2f} | Acc: {accuracies[i]*100:.2f}%',
fontsize=18, pad=20)
fig = plt.figure(figsize=(12, 12))
plt.axis('off')
anim = animation.FuncAnimation(fig, animate, \
np.arange(0, 200, 10), interval=500, repeat=True)
html = HTML(anim.to_html5_video())
display(html)
通过动画,我们可以观察到不同训练周期下图的节点颜色(代表某种特征或分类)如何变化,以及损失和准确率的变化情况。
最初的预测是随机的,但一段时间后,GCN 就能完美地标出每个节点。事实上,最终的图形与我们在第一节结尾绘制的图形相同。但是,GCN 到底学到了什么呢?
通过聚合相邻节点的特征,GCN 可以学习网络中每个节点的向量表示(或嵌入embedding)。在我们的模型中,最后一层只是学习如何使用这些表示来生成最佳分类。然而,嵌入才是 GNN 的真正产物。
让我们打印一下我们的模型所学习到的嵌入。
# Print embeddings
print(f'Final embeddings = {h.shape}')
print(h)
Final embeddings = torch.Size([34, 3])
tensor([[1.9099e+00, 2.3584e+00, 7.4027e-01],
[2.6203e+00, 2.7997e+00, 0.0000e+00],
[2.2567e+00, 2.2962e+00, 6.4663e-01],
[2.0802e+00, 2.8785e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 2.9694e+00],
[0.0000e+00, 0.0000e+00, 3.3817e+00],
[0.0000e+00, 1.5008e-04, 3.4246e+00],
[1.7593e+00, 2.4292e+00, 2.4551e-01],
[1.9757e+00, 6.1032e-01, 1.8986e+00],
[1.7770e+00, 1.9950e+00, 6.7018e-01],
[0.0000e+00, 1.1683e-04, 2.9738e+00],
[1.8988e+00, 2.0512e+00, 2.6225e-01],
[1.7081e+00, 2.3618e+00, 1.9609e-01],
[1.8303e+00, 2.1591e+00, 3.5906e-01],
[2.0755e+00, 2.7468e-01, 1.9804e+00],
[1.9676e+00, 3.7185e-01, 2.0011e+00],
[0.0000e+00, 0.0000e+00, 3.4787e+00],
[1.6945e+00, 2.0350e+00, 1.9789e-01],
[1.9808e+00, 3.2633e-01, 2.1349e+00],
[1.7846e+00, 1.9585e+00, 4.8021e-01],
[2.0420e+00, 2.7512e-01, 1.9810e+00],
[1.7665e+00, 2.1357e+00, 4.0325e-01],
[1.9870e+00, 3.3886e-01, 2.0421e+00],
[2.0614e+00, 5.1042e-01, 2.4872e+00],
[2.1778e+00, 4.4730e-01, 2.0077e+00],
[3.8906e-02, 2.3443e+00, 1.9195e+00],
[3.0748e+00, 0.0000e+00, 3.0789e+00],
[3.4316e+00, 1.9716e-01, 2.5231e+00]], grad_fn=<ReluBackward0>)
我们发现嵌入不需要与特征向量具有相同的维数。我们选择将维数从 34(dataset.num_features)减少到 3,以便我们可以在三维空间可视化效果。
# Get first embedding at epoch = 0
embed = h.detach().cpu().numpy()
fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection='3d')
ax.patch.set_alpha(0)
plt.tick_params(left=False,
bottom=False,
labelleft=False,
labelbottom=False)
ax.scatter(embed[:, 0], embed[:, 1], embed[:, 2],
s=200, c=data.y, cmap="hsv", vmin=-2, vmax=3)
plt.show()
我们可以看到 Zachary 空手道俱乐部的每个节点都带有真实标签(而非模型预测)。现在,由于 GNN 还没有训练好,所以它们都是乱七八糟的。但如果我们在训练循环的每一步都绘制这些嵌入,我们就能直观地看到 GNN 真正学到了什么。
让我们看看随着时间的推移,随着 GCN 对节点的分类能力越来越强,它们是如何演变的。
%%capture
def animate(i):
embed = embeddings[i].detach().cpu().numpy()
ax.clear()
ax.scatter(embed[:, 0], embed[:, 1], embed[:, 2],
s=200, c=data.y, cmap="hsv", vmin=-2, vmax=3)
plt.title(f'Epoch {i} | Loss: {losses[i]:.2f} | Acc: {accuracies[i]*100:.2f}%',
fontsize=18, pad=40)
fig = plt.figure(figsize=(12, 12))
plt.axis('off')
ax = fig.add_subplot(projection='3d')
plt.tick_params(left=False,
bottom=False,
labelleft=False,
labelbottom=False)
anim = animation.FuncAnimation(fig, animate, \
np.arange(0, 200, 10), interval=800, repeat=True)
html = HTML(anim.to_html5_video())
display(html)
我们的图卷积网络 (GCN) 通过有效学习嵌入,将相似节点归入不同的簇。这样,最后的线性层就能轻松地将它们区分为不同的类别
Reference:
https://towardsdatascience.com/graph-convolutional-networks-introduction-to-gnns-24b3f60d6c95
https://distill.pub/2021/gnn-intro/
https://github.com/TommyZihao/zihao_course/blob/main/CS224W/1-Intro.md