引言
图神经网络已经成功地应用于许多节点或边的预测任务,然而,在超大图上进行图神经网络的训练仍然具有挑战。普通的基于SGD的图神经网络的训练方法,要么面临着随着图神经网络层数增加,计算成本呈指数增长的问题,要么面临着保存整个图的信息和每一层每个节点的表征到内存(显存)而消耗巨大内存(显存)空间的问题。虽然已经有一些论文提出了无需保存整个图的信息和每一层每个节点的表征到GPU内存(显存)的方法,但这些方法可能会损失预测精度或者对提高内存的利用率并不明显。于是论文Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Network提出了一种新的图神经网络的训练方法。
Cluster-GCN方法简单概括
为了解决普通训练方法无法训练超大图的问题,Cluster-GCN论文提出:
- 利用图节点聚类算法将一个图的节点划分为 c c c个簇,每一次选择几个簇的节点和这些节点对应的边构成一个子图,然后对子图做训练。
- 由于是利用图节点聚类算法将节点划分为多个簇,所以簇内边的数量要比簇间边的数量多得多,所以可以提高表征利用率,并提高图神经网络的训练效率。
- 每一次随机选择多个簇来组成一个batch,这样不会丢失簇间的边,同时也不会有batch内类别分布偏差过大的问题。
- 基于小图进行训练,不会消耗很多内存空间,于是我们可以训练更深的神经网络,进而可以达到更高的精度。
节点表征学习回顾
给定一个图
G
=
(
V
,
E
,
A
)
G=(\mathcal{V}, \mathcal{E}, A)
G=(V,E,A),它由
N
=
∣
V
∣
N=|\mathcal{V}|
N=∣V∣个节点和
∣
E
∣
|\mathcal{E}|
∣E∣条边组成,其邻接矩阵记为
A
A
A,其节点属性记为
X
∈
R
N
×
F
X \in \mathbb{R}^{N \times F}
X∈RN×F,
F
F
F表示节点属性的维度。一个
L
L
L层的图卷积神经网络由
L
L
L个图卷积层组成,每一层都通过聚合邻接节点的上一层的表征来生成中心节点的当前层的表征:
Z
(
l
+
1
)
=
A
′
X
(
l
)
W
(
l
)
,
X
(
l
+
1
)
=
σ
(
Z
(
l
+
1
)
)
(1)
Z^{(l+1)}=A^{\prime} X^{(l)} W^{(l)}, X^{(l+1)}=\sigma\left(Z^{(l+1)}\right) \tag{1}
Z(l+1)=A′X(l)W(l),X(l+1)=σ(Z(l+1))(1)
其中
X
(
l
)
∈
R
N
×
F
l
X^{(l)} \in \mathbb{R}^{N \times F_{l}}
X(l)∈RN×Fl表示第
l
l
l层
N
N
N个节点的表征,并且有
X
(
0
)
=
X
X^{(0)}=X
X(0)=X。
A
′
A^{\prime}
A′是归一化和规范化后的邻接矩阵,
W
(
l
)
∈
R
F
l
×
F
l
+
1
W^{(l)} \in \mathbb{R}^{F_{l} \times F_{l+1}}
W(l)∈RFl×Fl+1是权重矩阵,也就是要训练的参数。为了简单起见,我们假设所有层的表征维度都是一样的,即
(
F
1
=
⋯
=
F
L
=
F
)
\left(F_{1}=\cdots=F_{L}=F\right)
(F1=⋯=FL=F)。激活函数
σ
(
⋅
)
\sigma(\cdot)
σ(⋅)通常被设定为ReLU
。
当图神经网络应用于半监督节点分类任务时,训练的目标是通过最小化损失函数来学习公式(1)中的权重矩阵:
L
=
1
∣
Y
L
∣
∑
i
∈
Y
L
loss
(
y
i
,
z
i
L
)
(2)
\mathcal{L}=\frac{1}{\left|\mathcal{Y}_{L}\right|} \sum_{i \in \mathcal{Y}_{L}} \operatorname{loss}\left(y_{i}, z_{i}^{L}\right) \tag{2}
L=∣YL∣1i∈YL∑loss(yi,ziL)(2)
其中,
Y
L
\mathcal{Y}_{L}
YL是节点类别;
z
i
(
L
)
z_{i}^{(L)}
zi(L)是
Z
(
L
)
Z^{(L)}
Z(L)的第
i
i
i行,表示对节点
i
i
i的预测,节点
i
i
i的真实类别为
y
i
y_{i}
yi。
Cluster-GCN方法详细分析
以往的训练方法需要同时计算所有节点的表征以及训练集中所有节点的损失产生的梯度(后文我们直接称为完整梯度)。这种训练方式需要非常巨大的计算开销和内存(显存)开销:在内存(显存)方面,计算公式(2)的完整梯度需要存储所有的节点表征矩阵 { Z ( l ) } l = 1 L \left\{Z^{(l)}\right\}_{l=1}^{L} {Z(l)}l=1L,这需要 O ( N F L ) O(N F L) O(NFL)的空间;在收敛速度方面,由于神经网络在每个epoch中只更新一次,所以训练需要更多的epoch才能达到收敛。
最近的一些工作证明,采用mini-batch SGD的方式训练,可以提高图神经网络的训练速度并减少内存(显存)需求。在参数更新中,SGD不需要计算完整梯度,而只需要基于mini-batch计算部分梯度。我们使用 B ⊆ [ N ] \mathcal{B} \subseteq[N] B⊆[N]来表示一个batch,其大小为 b = ∣ B ∣ b=|\mathcal{B}| b=∣B∣。SGD的每一步都将计算梯度估计值 1 ∣ B ∣ ∑ i ∈ B ∇ loss ( y i , z i ( L ) ) \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \nabla \operatorname{loss}\left(y_{i}, z_{i}^{(L)}\right) ∣B∣1∑i∈B∇loss(yi,zi(L))来进行参数更新。尽管在epoches数量相同的情况下,采用SGD方式进行训练,收敛速度可以更快,但此种训练方式会引入额外的时间开销,这使得相比于全梯度下降的训练方式,此种训练方式每个epoch的时间开销要大得多。
为什么采用最简单的mini-batch SGD方式进行训练,每个epoch需要的时间更多?我们将节点 i i i的梯度的计算表示为 ∇ loss ( y i , z i ( L ) ) \nabla \operatorname{loss}\left(y_{i}, z_{i}^{(L)}\right) ∇loss(yi,zi(L)),它依赖于节点 i i i的 L L L层的表征,而节点 i i i的非第 0 0 0层的表征都依赖于各自邻接节点的前一层的表征,这被称为邻域扩展。假设一个图神经网络有 L + 1 L+1 L+1层,节点的平均的度为 d d d。为了得到节点 i i i的梯度,平均我们需要聚合图上 O ( d L ) O\left(d^{L}\right) O(dL)的节点的表征。也就是说,我们需要获取节点的距离为 k ( k = 1 , ⋯ , L ) k(k=1, \cdots, L) k(k=1,⋯,L)的邻接节点的信息来进行一次参数更新。由于要与权重矩阵 W ( l ) W^{(l)} W(l)相乘,所以计算任意节点表征的时间开销是 O ( F 2 ) O\left(F^{2}\right) O(F2)。所以平均来说,一个节点的梯度的计算需要 O ( d L F 2 ) O\left(d^{L} F^{2}\right) O(dLF2)的时间。
节点表征的利用率可以反映出计算的效率。考虑到一个batch有多个节点,时间与空间复杂度的计算就不是上面那样简单了,因为不同的节点同样距离远的邻接节点可以是重叠的,于是计算表征的次数可以小于最坏的情况 O ( b d L ) O\left(b d^{L}\right) O(bdL)。为了反映mini-batch SGD的计算效率,Cluster-GCN论文提出了"表征利用率"的概念来描述计算效率。在训练过程中,如果节点 i i i在 l l l层的表征 z i ( l ) z_{i}^{(l)} zi(l)被计算并在 l + 1 l+1 l+1层的表征计算中被重复使用 u u u次,那么我们说 z i ( l ) z_{i}^{(l)} zi(l)的表征利用率为 u u u。对于随机抽样的mini-batch SGD, u u u非常小,因为图通常是大且稀疏的。假设 u u u是一个小常数(节点间同样距离的邻接节点重叠率小),那么mini-batch SGD的训练方式对每个batch需要计算 O ( b d L ) O\left(b d^{L}\right) O(bdL)的表征,于是每次参数更新需要 O ( b d L F 2 ) O\left(b d^{L} F^{2}\right) O(bdLF2)的时间,每个epoch需要 O ( N d L F 2 ) O\left(N d^{L} F^{2}\right) O(NdLF2)的时间,这被称为邻域扩展问题。
相反的是,全梯度下降训练具有最大的表征利用率——每个节点表征将在上一层被重复使用平均节点度次。因此,全梯度下降法在每个epoch中只需要计算
O
(
N
L
)
O(N L)
O(NL)的表征,这意味着平均下来只需要
O
(
L
)
O(L)
O(L)的表征计算就可以获得一个节点的梯度。
简单的Cluster-GCN方法
Cluster-GCN方法是由这样的问题驱动的:我们能否找到一种将节点分成多个batch的方式,对应地将图划分成多个子图,使得表征利用率最大?我们通过将表征利用率的概念与图节点聚类的目标联系起来来回答这个问题。
考虑到在每个batch中,我们计算一组节点(记为 B \mathcal{B} B)从第 1 1 1层到第 L L L层的表征。由于图神经网络每一层的计算都使用相同的子图 A B , B A_{\mathcal{B}, \mathcal{B}} AB,B( B \mathcal{B} B内部的边),所以表征利用率就是这个batch内边的数量,记为 ∥ A B , B ∥ 0 \left\|A_{\mathcal{B}, \mathcal{B}}\right\|_{0} ∥AB,B∥0。因此,为了最大限度地提高表征利用率,理想的划分batch的结果是,batch内的边尽可能多,batch之间的边尽可能少。基于这一点,我们将SGD图神经网络训练的效率与图聚类算法联系起来。
现在我们正式学习Cluster-GCN方法。对于一个图
G
G
G,我们将其节点划分为
c
c
c个簇:
V
=
[
V
1
,
⋯
V
c
]
\mathcal{V}=\left[\mathcal{V}_{1}, \cdots \mathcal{V}_{c}\right]
V=[V1,⋯Vc],其中
V
t
\mathcal{V}_{t}
Vt由第
t
t
t个簇中的节点组成,对应的我们有
c
c
c个子图:
KaTeX parse error: Undefined control sequence: \notag at position 161: …right\}\right] \̲n̲o̲t̲a̲g̲ ̲
其中
E
t
\mathcal{E}_{t}
Et只由
V
t
\mathcal{V}_{t}
Vt中的节点之间的边组成。经过节点重组,邻接矩阵被划分为大小为
c
2
c^{2}
c2的块矩阵,如下所示
A
=
A
ˉ
+
Δ
=
[
A
11
⋯
A
1
c
⋮
⋱
⋮
A
c
1
⋯
A
c
c
]
(4)
A=\bar{A}+\Delta=\left[\begin{array}{ccc} A_{11} & \cdots & A_{1 c} \\ \vdots & \ddots & \vdots \\ A_{c 1} & \cdots & A_{c c} \end{array}\right] \tag{4}
A=Aˉ+Δ=⎣⎢⎡A11⋮Ac1⋯⋱⋯A1c⋮Acc⎦⎥⎤(4)
其中
A
ˉ
=
[
A
11
⋯
0
⋮
⋱
⋮
0
⋯
A
c
c
]
,
Δ
=
[
0
⋯
A
1
c
⋮
⋱
⋮
A
c
1
⋯
0
]
(5)
\bar{A}=\left[\begin{array}{ccc} A_{11} & \cdots & 0 \\ \vdots & \ddots & \vdots \\ 0 & \cdots & A_{c c} \end{array}\right], \Delta=\left[\begin{array}{ccc} 0 & \cdots & A_{1 c} \\ \vdots & \ddots & \vdots \\ A_{c 1} & \cdots & 0 \end{array}\right] \tag{5}
Aˉ=⎣⎢⎡A11⋮0⋯⋱⋯0⋮Acc⎦⎥⎤,Δ=⎣⎢⎡0⋮Ac1⋯⋱⋯A1c⋮0⎦⎥⎤(5)
其中,对角线上的块
A
t
t
A_{t t}
Att是大小为
∣
V
t
∣
×
∣
V
t
∣
\left|\mathcal{V}_{t}\right| \times\left|\mathcal{V}_{t}\right|
∣Vt∣×∣Vt∣的邻接矩阵,它由
G
t
G_{t}
Gt内部的边构成。
A
ˉ
\bar{A}
Aˉ是图
G
ˉ
\bar{G}
Gˉ的邻接矩阵。
A
s
t
A_{s t}
Ast由两个簇
V
s
\mathcal{V}_{s}
Vs和
V
t
\mathcal{V}_{t}
Vt之间的边构成。
Δ
\Delta
Δ是由
A
A
A的所有非对角线块组成的矩阵。同样,我们可以根据
[
V
1
,
⋯
,
V
c
]
\left[\mathcal{V}_{1}, \cdots, \mathcal{V}_{c}\right]
[V1,⋯,Vc]划分节点表征矩阵
X
X
X和类别向量
Y
Y
Y,得到
[
X
1
,
⋯
,
X
c
]
\left[X_{1}, \cdots, X_{c}\right]
[X1,⋯,Xc]和
[
Y
1
,
⋯
,
Y
c
]
\left[Y_{1}, \cdots, Y_{c}\right]
[Y1,⋯,Yc],其中
X
t
X_{t}
Xt和
Y
t
Y_{t}
Yt分别由
V
t
V_{t}
Vt中节点的表征和类别组成。
接下来我们用块对角线邻接矩阵
A
ˉ
\bar{A}
Aˉ去近似邻接矩阵
A
A
A,这样做的好处是,完整的损失函数(公示(2))可以根据batch分解成多个部分之和。以
A
ˉ
′
\bar{A}^{\prime}
Aˉ′表示归一化后的
A
ˉ
\bar{A}
Aˉ,最后一层节点表征矩阵可以做如下的分解:
Z
(
L
)
=
A
ˉ
′
σ
(
A
ˉ
′
σ
(
⋯
σ
(
A
ˉ
′
X
W
(
0
)
)
W
(
1
)
)
⋯
)
W
(
L
−
1
)
=
[
A
ˉ
11
′
σ
(
A
ˉ
11
′
σ
(
⋯
σ
(
A
ˉ
11
′
X
1
W
(
0
)
)
W
(
1
)
)
⋯
)
W
(
L
−
1
)
⋮
A
ˉ
c
c
′
σ
(
A
ˉ
c
c
′
σ
(
⋯
σ
(
A
ˉ
c
c
′
X
c
W
(
0
)
)
W
(
1
)
)
⋯
)
W
(
L
−
1
)
]
(6)
\begin{aligned} Z^{(L)} &=\bar{A}^{\prime} \sigma\left(\bar{A}^{\prime} \sigma\left(\cdots \sigma\left(\bar{A}^{\prime} X W^{(0)}\right) W^{(1)}\right) \cdots\right) W^{(L-1)} \\ &=\left[\begin{array}{c} \bar{A}_{11}^{\prime} \sigma\left(\bar{A}_{11}^{\prime} \sigma\left(\cdots \sigma\left(\bar{A}_{11}^{\prime} X_{1} W^{(0)}\right) W^{(1)}\right) \cdots\right) W^{(L-1)} \\ \vdots \\ \bar{A}_{c c}^{\prime} \sigma\left(\bar{A}_{c c}^{\prime} \sigma\left(\cdots \sigma\left(\bar{A}_{c c}^{\prime} X_{c} W^{(0)}\right) W^{(1)}\right) \cdots\right) W^{(L-1)} \end{array}\right] \end{aligned} \tag{6}
Z(L)=Aˉ′σ(Aˉ′σ(⋯σ(Aˉ′XW(0))W(1))⋯)W(L−1)=⎣⎢⎡Aˉ11′σ(Aˉ11′σ(⋯σ(Aˉ11′X1W(0))W(1))⋯)W(L−1)⋮Aˉcc′σ(Aˉcc′σ(⋯σ(Aˉcc′XcW(0))W(1))⋯)W(L−1)⎦⎥⎤(6)
由于
A
ˉ
\bar{A}
Aˉ是块对角形式(
A
ˉ
t
t
′
\bar{A}_{t t}^{\prime}
Aˉtt′是
A
ˉ
′
\bar{A}^{\prime}
Aˉ′的对角线上的块),于是损失函数可以分解为
L
A
ˉ
′
=
∑
t
∣
V
t
∣
N
L
A
ˉ
t
t
′
and
L
A
ˉ
t
t
′
=
1
∣
V
t
∣
∑
i
∈
V
t
loss
(
y
i
,
z
i
(
L
)
)
(7)
\mathcal{L}_{\bar{A}^{\prime}}=\sum_{t} \frac{\left|\mathcal{V}_{t}\right|}{N} \mathcal{L}_{\bar{A}_{t t}^{\prime}} \text { and } \mathcal{L}_{\bar{A}_{t t}^{\prime}}=\frac{1}{\left|\mathcal{V}_{t}\right|} \sum_{i \in \mathcal{V}_{t}} \operatorname{loss}\left(y_{i}, z_{i}^{(L)}\right) \tag{7}
LAˉ′=t∑N∣Vt∣LAˉtt′ and LAˉtt′=∣Vt∣1i∈Vt∑loss(yi,zi(L))(7)
基于公式(6)和公式(7),在训练的每一步中,Cluster-GCN首先采样一个簇
V
t
\mathcal{V}_{t}
Vt,然后根据
L
A
ˉ
′
t
t
\mathcal{L}_{{\bar{A}^{\prime}}_{tt}}
LAˉ′tt的梯度进行参数更新。这种训练方式,只需要用到子图
A
t
t
A_{t t}
Att,
X
t
X_{t}
Xt,
Y
t
Y_{t}
Yt以及神经网络权重矩阵
{
W
(
l
)
}
l
=
1
L
\left\{W^{(l)}\right\}_{l=1}^{L}
{W(l)}l=1L。 实际中,主要的计算开销在神经网络前向过程中的矩阵乘法运算(公式(6)的一个行)和梯度反向传播。
我们使用图节点聚类算法来划分图。图节点聚类算法将图节点分成多个簇,划分结果是簇内边的数量远多于簇间边的数量。如前所述,每个batch的表征利用率相当于簇内边的数量。直观地说,每个节点和它的邻接节点大部分情况下都位于同一个簇中,因此** L L L跳(L-hop)远的邻接节点大概率仍然在同一个簇中**。由于我们用块对角线近似邻接矩阵 A ˉ \bar{A} Aˉ代替邻接矩阵 A A A,产生的误差与簇间的边的数量 Δ \Delta Δ成正比,所以簇间的边越少越好。综上所述,使用图节点聚类算法对图节点划分多个簇的结果,正是我们希望得到的。
训练深层GCNs的问题
以往尝试训练更深的GCN的研究似乎表明,增加更多的层是没有帮助的。然而,那些研究的实验使用的图太小,所以结论可能并不正确。例如,其中有一项研究只使用了一个只有几百个训练节点的图,由于节点数量过少,很容易出现过拟合的问题。此外,加深GCN神经网络层数后,训练变得很困难,因为层数多了之后前面的信息可能无法传到后面。有的研究采用了一种类似于残差连接的技术,使模型能够将前一层的信息直接传到下一层。具体来说,他们修改了公式(1),将第
l
l
l层的表征添加到下一层,如下所示
X
(
l
+
1
)
=
σ
(
A
′
X
(
l
)
W
(
l
)
)
+
X
(
l
)
(8)
X^{(l+1)}=\sigma\left(A^{\prime} X^{(l)} W^{(l)}\right)+X^{(l)} \tag{8}
X(l+1)=σ(A′X(l)W(l))+X(l)(8)
在这里,我们提出了另一种简单的技术来改善深层GCN神经网络的训练。在原始的GCN的设置里,每个节点都聚合邻接节点在上一层的表征。然而,在深层GCN的设置里,该策略可能不适合,因为它没有考虑到层数的问题。直观地说,近距离的邻接节点应该比远距离的的邻接节点贡献更大。因此,Cluster-GCN提出一种技术来更好地解决这个问题。其主要思想是放大GCN每一层中使用的邻接矩阵
A
A
A的对角线部分。通过这种方式,我们在GCN的每一层的聚合中对来自上一层的表征赋予更大的权重。这可以通过给
A
ˉ
\bar{A}
Aˉ加上一个单位矩阵
I
I
I来实现,如下所示,
X
(
l
+
1
)
=
σ
(
(
A
′
+
I
)
X
(
l
)
W
(
l
)
)
(9)
X^{(l+1)}=\sigma\left(\left(A^{\prime}+I\right) X^{(l)} W^{(l)}\right) \tag{9}
X(l+1)=σ((A′+I)X(l)W(l))(9)
虽然公式(9)似乎是合理的,但对所有节点使用相同的权重而不考虑其邻居的数量可能不合适。此外,它可能会受到数值不稳定的影响,因为当使用更多的层时,数值会呈指数级增长。因此,Cluster-GCN方法提出了一个修改版的公式(9),以更好地保持邻接节点信息和数值范围。首先给原始的
A
A
A添加一个单位矩阵
I
I
I,并进行归一化处理
A
~
=
(
D
+
I
)
−
1
(
A
+
I
)
(10)
\tilde{A}=(D+I)^{-1}(A+I) \tag{10}
A~=(D+I)−1(A+I)(10)
然后考虑,
X
(
l
+
1
)
=
σ
(
(
A
~
+
λ
diag
(
A
~
)
)
X
(
l
)
W
(
l
)
)
(11)
X^{(l+1)}=\sigma\left((\tilde{A}+\lambda \operatorname{diag}(\tilde{A})) X^{(l)} W^{(l)}\right) \tag{11}
X(l+1)=σ((A~+λdiag(A~))X(l)W(l))(11)
Cluster-GCN实践
from torch_geometric.datasets import Reddit
from torch_geometric.data import ClusterData, ClusterLoader, NeighborSampler
dataset = Reddit('../dataset/Reddit')
data = dataset[0]
print(dataset.num_classes)
print(data.num_nodes)
print(data.num_edges)
print(data.num_features)
# 41
# 232965
# 114615873
# 602
train_loader
,此数据加载器遵循Cluster-GCN提出的方法,图节点被聚类划分成多个簇,此数据加载器返回的一个batch由多个簇组成。
subgraph_loader
,使用此数据加载器不对图节点聚类,计算一个batch中的节点的表征需要计算该batch中的所有节点的距离从
0
0
0到
L
L
L的邻居节点。
图神经网络的构建
class Net(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super(Net, self).__init__()
self.convs = ModuleList(
[SAGEConv(in_channels, 128),
SAGEConv(128, out_channels)])
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs):
x = conv(x, edge_index)
if i != len(self.convs) - 1:
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
return F.log_softmax(x, dim=-1)
def inference(self, x_all):
pbar = tqdm(total=x_all.size(0) * len(self.convs))
pbar.set_description('Evaluating')
# Compute representations of nodes layer by layer, using *all*
# available edges. This leads to faster computation in contrast to
# immediately computing the final representations of each batch.
for i, conv in enumerate(self.convs):
xs = []
for batch_size, n_id, adj in subgraph_loader:
edge_index, _, size = adj.to(device)
x = x_all[n_id].to(device)
x_target = x[:size[1]]
x = conv((x, x_target), edge_index)
if i != len(self.convs) - 1:
x = F.relu(x)
xs.append(x.cpu())
pbar.update(batch_size)
x_all = torch.cat(xs, dim=0)
pbar.close()
return x_all
训练、验证与测试
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
def train():
model.train()
total_loss = total_nodes = 0
for batch in train_loader:
batch = batch.to(device)
optimizer.zero_grad()
out = model(batch.x, batch.edge_index)
loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
loss.backward()
optimizer.step()
nodes = batch.train_mask.sum().item()
total_loss += loss.item() * nodes
total_nodes += nodes
return total_loss / total_nodes
@torch.no_grad()
def test(): # Inference should be performed on the full graph.
model.eval()
out = model.inference(data.x)
y_pred = out.argmax(dim=-1)
accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
correct = y_pred[mask].eq(data.y[mask]).sum().item()
accs.append(correct / mask.sum().item())
return accs
for epoch in range(1, 31):
loss = train()
if epoch % 5 == 0:
train_acc, val_acc, test_acc = test()
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
f'Val: {val_acc:.4f}, test: {test_acc:.4f}')
else:
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')