Task05:超大图上的节点表征学习
本文参考datawhale开源学习资料
一、Cluster-GCN简介
为了解决普通训练方法无法训练超大图的问题,Cluster-GCN论文提出:
- 利用图节点聚类算法将一个图的节点划分为 c c c个簇,每一次选择几个簇的节点和这些节点对应的边构成一个子图,然后对子图做训练。
- 由于是利用图节点聚类算法将节点划分为多个簇,所以簇内边的数量要比簇间边的数量多得多,所以可以提高表征利用率,并提高图神经网络的训练效率。
- 每一次随机选择多个簇来组成一个batch,这样不会丢失簇间的边,同时也不会有batch内类别分布偏差过大的问题。
- 基于小图进行训练,不会消耗很多内存空间,于是我们可以训练更深的神经网络,进而可以达到更高的精度。
二、Cluster-GCN分析
详细可见datawhale开源资料
1. 以往训练方法的瓶颈
模型 | 训练方法 | 空间复杂度 | 时间复杂度 | 收敛性 |
---|---|---|---|---|
GCN | full-batch gradient descent | 差 | 好 | 差 |
GraphSAGE | mini-batch gradient descent | 好 | 差 | 好 |
VR-GCN | reduce size of neighborhood sampleing node | 差 | 好 | 好 |
2. Vanilla Cluster-GCN
对于一个图
G
G
G,我们将其节点划分为
c
c
c个簇:
V
=
[
V
1
,
⋯
V
c
]
\mathcal{V}=\left[\mathcal{V}_{1}, \cdots \mathcal{V}_{c}\right]
V=[V1,⋯Vc],其中
V
t
\mathcal{V}_{t}
Vt由第
t
t
t个簇中的节点组成,对应的我们有
c
c
c个子图:
G
ˉ
=
[
G
1
,
⋯
,
G
c
]
=
[
{
V
1
,
E
1
}
,
⋯
,
{
V
c
,
E
c
}
]
\bar{G}=\left[G_{1}, \cdots, G_{c}\right]=\left[\left\{\mathcal{V}_{1}, \mathcal{E}_{1}\right\}, \cdots,\left\{\mathcal{V}_{c}, \mathcal{E}_{c}\right\}\right]
Gˉ=[G1,⋯,Gc]=[{V1,E1},⋯,{Vc,Ec}]
其中
E
t
\mathcal{E}_{t}
Et只由
V
t
\mathcal{V}_{t}
Vt中的节点之间的边组成。经过节点重组,邻接矩阵被划分为大小为
c
2
c^{2}
c2的块矩阵,如下所示
A
=
A
ˉ
+
Δ
=
[
A
11
⋯
A
1
c
⋮
⋱
⋮
A
c
1
⋯
A
c
c
]
(1)
A=\bar{A}+\Delta=\left[\begin{array}{ccc} A_{11} & \cdots & A_{1 c} \\ \vdots & \ddots & \vdots \\ A_{c 1} & \cdots & A_{c c} \end{array}\right] \tag{1}
A=Aˉ+Δ=⎣⎢⎡A11⋮Ac1⋯⋱⋯A1c⋮Acc⎦⎥⎤(1)
其中
A
ˉ
=
[
A
11
⋯
0
⋮
⋱
⋮
0
⋯
A
c
c
]
,
Δ
=
[
0
⋯
A
1
c
⋮
⋱
⋮
A
c
1
⋯
0
]
(2)
\bar{A}=\left[\begin{array}{ccc} A_{11} & \cdots & 0 \\ \vdots & \ddots & \vdots \\ 0 & \cdots & A_{c c} \end{array}\right], \Delta=\left[\begin{array}{ccc} 0 & \cdots & A_{1 c} \\ \vdots & \ddots & \vdots \\ A_{c 1} & \cdots & 0 \end{array}\right] \tag{2}
Aˉ=⎣⎢⎡A11⋮0⋯⋱⋯0⋮Acc⎦⎥⎤,Δ=⎣⎢⎡0⋮Ac1⋯⋱⋯A1c⋮0⎦⎥⎤(2)
其中,对角线上的块
A
t
t
A_{t t}
Att是大小为
∣
V
t
∣
×
∣
V
t
∣
\left|\mathcal{V}_{t}\right| \times\left|\mathcal{V}_{t}\right|
∣Vt∣×∣Vt∣的邻接矩阵,它由
G
t
G_{t}
Gt内部的边构成。
A
ˉ
\bar{A}
Aˉ是图
G
ˉ
\bar{G}
Gˉ的邻接矩阵。
A
s
t
A_{s t}
Ast由两个簇
V
s
\mathcal{V}_{s}
Vs和
V
t
\mathcal{V}_{t}
Vt之间的边构成。
Δ
\Delta
Δ是由
A
A
A的所有非对角线块组成的矩阵。同样,我们可以根据
[
V
1
,
⋯
,
V
c
]
\left[\mathcal{V}_{1}, \cdots, \mathcal{V}_{c}\right]
[V1,⋯,Vc]划分节点表征矩阵
X
X
X和类别向量
Y
Y
Y,得到
[
X
1
,
⋯
,
X
c
]
\left[X_{1}, \cdots, X_{c}\right]
[X1,⋯,Xc]和
[
Y
1
,
⋯
,
Y
c
]
\left[Y_{1}, \cdots, Y_{c}\right]
[Y1,⋯,Yc],其中
X
t
X_{t}
Xt和
Y
t
Y_{t}
Yt分别由
V
t
V_{t}
Vt中节点的表征和类别组成。
接下来我们用块对角线邻接矩阵
A
ˉ
\bar{A}
Aˉ去近似邻接矩阵
A
A
A,这样做的好处是,完整的损失函数可以根据batch分解成多个部分之和。以
A
ˉ
′
\bar{A}^{\prime}
Aˉ′表示归一化后的
A
ˉ
\bar{A}
Aˉ,最后一层节点表征矩阵可以做如下的分解:
Z
(
L
)
=
A
ˉ
′
σ
(
A
ˉ
′
σ
(
⋯
σ
(
A
ˉ
′
X
W
(
0
)
)
W
(
1
)
)
⋯
)
W
(
L
−
1
)
=
[
A
ˉ
11
′
σ
(
A
ˉ
11
′
σ
(
⋯
σ
(
A
ˉ
11
′
X
1
W
(
0
)
)
W
(
1
)
)
⋯
)
W
(
L
−
1
)
⋮
A
ˉ
c
c
′
σ
(
A
ˉ
c
c
′
σ
(
⋯
σ
(
A
ˉ
c
c
′
X
c
W
(
0
)
)
W
(
1
)
)
⋯
)
W
(
L
−
1
)
]
(3)
\begin{aligned} Z^{(L)} &=\bar{A}^{\prime} \sigma\left(\bar{A}^{\prime} \sigma\left(\cdots \sigma\left(\bar{A}^{\prime} X W^{(0)}\right) W^{(1)}\right) \cdots\right) W^{(L-1)} \\ &=\left[\begin{array}{c} \bar{A}_{11}^{\prime} \sigma\left(\bar{A}_{11}^{\prime} \sigma\left(\cdots \sigma\left(\bar{A}_{11}^{\prime} X_{1} W^{(0)}\right) W^{(1)}\right) \cdots\right) W^{(L-1)} \\ \vdots \\ \bar{A}_{c c}^{\prime} \sigma\left(\bar{A}_{c c}^{\prime} \sigma\left(\cdots \sigma\left(\bar{A}_{c c}^{\prime} X_{c} W^{(0)}\right) W^{(1)}\right) \cdots\right) W^{(L-1)} \end{array}\right] \end{aligned} \tag{3}
Z(L)=Aˉ′σ(Aˉ′σ(⋯σ(Aˉ′XW(0))W(1))⋯)W(L−1)=⎣⎢⎡Aˉ11′σ(Aˉ11′σ(⋯σ(Aˉ11′X1W(0))W(1))⋯)W(L−1)⋮Aˉcc′σ(Aˉcc′σ(⋯σ(Aˉcc′XcW(0))W(1))⋯)W(L−1)⎦⎥⎤(3)
由于
A
ˉ
\bar{A}
Aˉ是块对角形式(
A
ˉ
t
t
′
\bar{A}_{t t}^{\prime}
Aˉtt′是
A
ˉ
′
\bar{A}^{\prime}
Aˉ′的对角线上的块),于是损失函数可以分解为
L
A
ˉ
′
=
∑
t
∣
V
t
∣
N
L
A
ˉ
t
t
′
and
L
A
ˉ
t
t
′
=
1
∣
V
t
∣
∑
i
∈
V
t
loss
(
y
i
,
z
i
(
L
)
)
(4)
\mathcal{L}_{\bar{A}^{\prime}}=\sum_{t} \frac{\left|\mathcal{V}_{t}\right|}{N} \mathcal{L}_{\bar{A}_{t t}^{\prime}} \text { and } \mathcal{L}_{\bar{A}_{t t}^{\prime}}=\frac{1}{\left|\mathcal{V}_{t}\right|} \sum_{i \in \mathcal{V}_{t}} \operatorname{loss}\left(y_{i}, z_{i}^{(L)}\right) \tag{4}
LAˉ′=t∑N∣Vt∣LAˉtt′ and LAˉtt′=∣Vt∣1i∈Vt∑loss(yi,zi(L))(4)
基于公式(3)和公式(4),在训练的每一步中,Cluster-GCN首先采样一个簇
V
t
\mathcal{V}_{t}
Vt,然后根据
L
A
ˉ
′
t
t
\mathcal{L}_{{\bar{A}^{\prime}}_{tt}}
LAˉ′tt的梯度进行参数更新。这种训练方式,只需要用到子图
A
t
t
A_{t t}
Att,
X
t
X_{t}
Xt,
Y
t
Y_{t}
Yt以及神经网络权重矩阵
{
W
(
l
)
}
l
=
1
L
\left\{W^{(l)}\right\}_{l=1}^{L}
{W(l)}l=1L。 实际中,主要的计算开销在神经网络前向过程中的矩阵乘法运算(公式(3)的一个行)和梯度反向传播。
我们使用图节点聚类算法来划分图。图节点聚类算法将图节点分成多个簇,划分结果是簇内边的数量远多于簇间边的数量。如前所述,每个batch的表征利用率相当于簇内边的数量。直观地说,每个节点和它的邻接节点大部分情况下都位于同一个簇中,因此** L L L跳(L-hop)远的邻接节点大概率仍然在同一个簇中**。由于我们用块对角线近似邻接矩阵 A ˉ \bar{A} Aˉ代替邻接矩阵 A A A,产生的误差与簇间的边的数量 Δ \Delta Δ成正比,所以簇间的边越少越好。综上所述,使用图节点聚类算法对图节点划分多个簇的结果,正是我们希望得到的。
图1展示了Cluster-GCN方法可以避免巨大范围的邻域扩展(图右),因为Cluster-GCN方法将邻域扩展限制在簇内。
表1显示了两种不同的节点划分策略:随机划分与聚类划分。两者都使用一个分区作为一个batch来进行神经网络训练。我们可以看到,在相同的epoches下,使用聚类分区可以达到更高的精度。
表1:随机分区与聚类分区的对比(采用mini-batch SGD训练)。聚类分区得到更好的性能(就测试F1集得分而言),因为它删除的分区间的边较少。
Dataset | random partition | clustering partition |
---|---|---|
Cora | 78.4 | 82.5 |
Pubmed | 78.9 | 79.9 |
PPI | 68.1 | 92.9 |
2. 随机多分区(Stochastic Multiple Partitions)
尽管简单Cluster-GCN方法可以做到较其他方法更低的计算和内存复杂度,但它仍存在两个潜在问题:
- 图被分割后,一些边(公式(1)中的 Δ \Delta Δ部分)被移除,性能可能因此会受到影响。
- 图聚类算法倾向于将相似的节点聚集在一起。因此,单个簇中节点的类别分布可能与原始数据集不同,导致对梯度的估计有偏差。
图2展示了一个类别分布不平衡的例子,该例子使用Reddit数据集,节点聚类由Metis软件包实现。根据各个簇的类别分布来计算熵值。与随机划分相比,采用聚类划分得到的大多数簇熵值都很小,簇熵值小表明簇中节点的标签分布偏向于某一些类别,这意味着不同簇的标签分布有较大的差异,这将影响训练的收敛。
图2:类别分布熵值柱状图。类别分布熵越高意味着簇内类别分布越平衡,反之意味着簇内类别分布越不平衡。此图展示了不同熵值的随机分区和聚类分区的簇的数量,大多数聚类分区的簇具有较低的熵,表明各个簇内节点的类别分布存在偏差。相比之下,随机分区会产生类别分布熵很高的簇,尽管基于随机分区的训练的效率较低。在这个例子中,使用了Reddit数据集,进行了300个簇的分区。
为了解决上述问题,Cluster-GCN论文提出了一种随机多簇方法,此方法首先将图划分为 p p p个簇, V 1 , ⋯ , V p \mathcal{V}_{1}, \cdots, \mathcal{V}_{p} V1,⋯,Vp, p p p是一个较大的值,在构建一个batch时,不是只使用一个簇,而是使用随机选择的 q q q个簇,表示为 t 1 , … , t q t_{1}, \ldots, t_{q} t1,…,tq,得到的batch包含节点 { V t 1 ∪ ⋯ ∪ V t q } \left\{\mathcal{V}_{t_{1}} \cup \cdots \cup \mathcal{V}_{t_{q}}\right\} {Vt1∪⋯∪Vtq} 、簇内边 { A i i ∣ i ∈ t 1 , … , t q } \left\{A_{i i} \mid i \in t_{1}, \ldots, t_{q}\right\} {Aii∣i∈t1,…,tq}和簇间边 { A i j ∣ i , j ∈ t 1 , … , t q } \left\{A_{i j} \mid i, j \in t_{1}, \ldots, t_{q}\right\} {Aij∣i,j∈t1,…,tq}。此方法的好处有,1)不会丢失簇间的边,2)不会有很大的batch内类别分布的偏差,3)以及不同的epoch使用的batch不同,这可以降低梯度估计的偏差。
图3展示了随机多簇方法,在每个epoch中,随机选择一些簇来组成一个batch,不同的epoch的batch不同。在图4中,我们可以观察到,使用多个簇来组成一个batch可以提高收敛性。
图3:Cluster-GCN提出的随机多分区方法。在每个epoch中,我们(不放回地)随机抽取
q
q
q个簇(本例中使用
q
q
q=2)及其簇间的边,来构成一个batch(相同颜色的块在同一batch中)。
图4:选择一个簇与选择多个簇的比较。前者使用300个簇。后者使用1500个簇,并随机选择5个簇来组成一个batch。该图X轴为epoches,Y轴为F1得分。
3. 训练深层GCNs的问题
以往尝试训练更深的GCN的研究似乎表明,增加更多的层是没有帮助的。然而,那些研究的实验使用的图太小,所以结论可能并不正确。例如,其中有一项研究只使用了一个只有几百个训练节点的图,由于节点数量过少,很容易出现过拟合的问题。此外,加深GCN神经网络层数后,训练变得很困难,因为层数多了之后前面的信息可能无法传到后面。有的研究采用了一种类似于残差连接的技术,使模型能够将前一层的信息直接传到下一层。如下所示:
X
(
l
+
1
)
=
σ
(
A
′
X
(
l
)
W
(
l
)
)
+
X
(
l
)
(5)
X^{(l+1)}=\sigma\left(A^{\prime} X^{(l)} W^{(l)}\right)+X^{(l)} \tag{5}
X(l+1)=σ(A′X(l)W(l))+X(l)(5)
在这里,我们提出了另一种简单的技术来改善深层GCN神经网络的训练。在原始的GCN的设置里,每个节点都聚合邻接节点在上一层的表征。然而,在深层GCN的设置里,该策略可能不适合,因为它没有考虑到层数的问题。直观地说,近距离的邻接节点应该比远距离的的邻接节点贡献更大。因此,Cluster-GCN提出一种技术来更好地解决这个问题。其主要思想是放大GCN每一层中使用的邻接矩阵
A
A
A的对角线部分。通过这种方式,我们在GCN的每一层的聚合中对来自上一层的表征赋予更大的权重。这可以通过给
A
ˉ
\bar{A}
Aˉ加上一个单位矩阵
I
I
I来实现,如下所示,
X
(
l
+
1
)
=
σ
(
(
A
′
+
I
)
X
(
l
)
W
(
l
)
)
(6)
X^{(l+1)}=\sigma\left(\left(A^{\prime}+I\right) X^{(l)} W^{(l)}\right) \tag{6}
X(l+1)=σ((A′+I)X(l)W(l))(6)
虽然公式(6)似乎是合理的,但对所有节点使用相同的权重而不考虑其邻居的数量可能不合适。此外,它可能会受到数值不稳定的影响,因为当使用更多的层时,数值会呈指数级增长。因此,Cluster-GCN方法提出了一个修改版的公式(6),以更好地保持邻接节点信息和数值范围。首先给原始的
A
A
A添加一个单位矩阵
I
I
I,并进行归一化处理
A
~
=
(
D
+
I
)
−
1
(
A
+
I
)
(7)
\tilde{A}=(D+I)^{-1}(A+I) \tag{7}
A~=(D+I)−1(A+I)(7)
然后考虑,
X
(
l
+
1
)
=
σ
(
(
A
~
+
λ
diag
(
A
~
)
)
X
(
l
)
W
(
l
)
)
(8)
X^{(l+1)}=\sigma\left((\tilde{A}+\lambda \operatorname{diag}(\tilde{A})) X^{(l)} W^{(l)}\right) \tag{8}
X(l+1)=σ((A~+λdiag(A~))X(l)W(l))(8)
三、Cluster-GCN实践
1. 数据集分析
from torch_geometric.datasets import Reddit
from torch_geometric.data import ClusterData, ClusterLoader, NeighborSampler
dataset = Reddit('../dataset/Reddit')
data = dataset[0]
print(dataset.num_classes)
print(data.num_nodes)
print(data.num_edges)
print(data.num_features)
运行结果:
Downloading https://data.dgl.ai/dataset/reddit.zip
Extracting ..\dataset\Reddit\raw\reddit.zip
Processing...
Done!
41
232965
114615892
602
可以看到该数据集包含41个分类任务,232,965个节点,114,615,873条边,节点维度为602维。
2. 图节点聚类与数据加载器生成
cluster_data = ClusterData(data, num_parts=1500, recursive=False, save_dir=dataset.processed_dir)
train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True, num_workers=12)
subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=1024, shuffle=False, num_workers=12)
运行结果:
Computing METIS partitioning...
Done!
train_loader
,此数据加载器遵循Cluster-GCN提出的方法,图节点被聚类划分成多个簇,此数据加载器返回的一个batch由多个簇组成。
subgraph_loader
,使用此数据加载器不对图节点聚类,计算一个batch中的节点的表征需要计算该batch中的所有节点的距离从
0
0
0到
L
L
L的邻居节点。
3. 图神经网络的构建
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from tqdm import tqdm
from torch_geometric.nn import SAGEConv
class Net(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super(Net, self).__init__()
self.convs = ModuleList(
[SAGEConv(in_channels, 128),
SAGEConv(128, out_channels)])
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs):
x = conv(x, edge_index)
if i != len(self.convs) - 1:
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
return F.log_softmax(x, dim=-1)
def inference(self, x_all):
pbar = tqdm(total=x_all.size(0) * len(self.convs))
pbar.set_description('Evaluating')
# Compute representations of nodes layer by layer, using *all*
# available edges. This leads to faster computation in contrast to
# immediately computing the final representations of each batch.
for i, conv in enumerate(self.convs):
xs = []
for batch_size, n_id, adj in subgraph_loader:
edge_index, _, size = adj.to(device)
x = x_all[n_id].to(device)
x_target = x[:size[1]]
x = conv((x, x_target), edge_index)
if i != len(self.convs) - 1:
x = F.relu(x)
xs.append(x.cpu())
pbar.update(batch_size)
x_all = torch.cat(xs, dim=0)
pbar.close()
return x_all
model = Net(dataset.num_features, dataset.num_classes).to(device)
print(model)
可以看到此神经网络拥有forward
和inference
两个方法。forward
函数的定义与普通的图神经网络并无区别。inference
方法应用于推理阶段,为了获取更高的预测精度,所以使用subgraph_loader
。
运行结果:
Net(
(convs): ModuleList(
(0): SAGEConv(602, 128)
(1): SAGEConv(128, 41)
)
)
4. 训练、验证与测试
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
def train():
model.train()
total_loss = total_nodes = 0
for batch in train_loader:
batch = batch.to(device)
optimizer.zero_grad()
out = model(batch.x, batch.edge_index)
loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
loss.backward()
optimizer.step()
nodes = batch.train_mask.sum().item()
total_loss += loss.item() * nodes
total_nodes += nodes
return total_loss / total_nodes
@torch.no_grad()
def test(): # Inference should be performed on the full graph.
model.eval()
out = model.inference(data.x)
y_pred = out.argmax(dim=-1)
accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
correct = y_pred[mask].eq(data.y[mask]).sum().item()
accs.append(correct / mask.sum().item())
return accs
for epoch in range(1, 31):
loss = train()
if epoch % 5 == 0:
train_acc, val_acc, test_acc = test()
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
f'Val: {val_acc:.4f}, test: {test_acc:.4f}')
else:
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
可见在训练过程中,我们使用train_loader
获取batch,每次根据多个簇组成的batch进行神经网络的训练。但在验证阶段,我们使用subgraph_loader
,在计算一个节点的表征时会计算该节点的距离从
0
0
0到
L
L
L的邻接节点,这么做可以更好地测试神经网络的性能。
运行结果:
Epoch: 01, Loss: 1.1692
Epoch: 02, Loss: 0.4743
Epoch: 03, Loss: 0.3937
Epoch: 04, Loss: 0.3554
Evaluating: 100%|█████████████████████████████████████████████████████████████| 465930/465930 [18:36<00:00, 417.46it/s]
Epoch: 05, Loss: 0.3465, Train: 0.9570, Val: 0.9557, test: 0.9527
Epoch: 06, Loss: 0.3177
Epoch: 07, Loss: 0.3175
Epoch: 08, Loss: 0.3054
Epoch: 09, Loss: 0.2904
Evaluating: 100%|█████████████████████████████████████████████████████████████| 465930/465930 [18:13<00:00, 426.15it/s]
Epoch: 10, Loss: 0.3034, Train: 0.9530, Val: 0.9456, test: 0.9439
Epoch: 11, Loss: 0.2816
Epoch: 12, Loss: 0.2738
Epoch: 13, Loss: 0.2745
Epoch: 14, Loss: 0.2858
Evaluating: 100%|█████████████████████████████████████████████████████████████| 465930/465930 [17:41<00:00, 439.03it/s]
Epoch: 15, Loss: 0.2681, Train: 0.9657, Val: 0.9549, test: 0.9521
Epoch: 16, Loss: 0.2662
Epoch: 17, Loss: 0.2626
Epoch: 18, Loss: 0.2564
Epoch: 19, Loss: 0.2780
Evaluating: 100%|█████████████████████████████████████████████████████████████| 465930/465930 [17:39<00:00, 439.92it/s]
Epoch: 20, Loss: 0.2623, Train: 0.9639, Val: 0.9477, test: 0.9466
Epoch: 21, Loss: 0.2503
Epoch: 22, Loss: 0.2437
Epoch: 23, Loss: 0.2382
Epoch: 24, Loss: 0.2426
Evaluating: 100%|█████████████████████████████████████████████████████████████| 465930/465930 [17:38<00:00, 440.08it/s]
Epoch: 25, Loss: 0.2419, Train: 0.9680, Val: 0.9523, test: 0.9512
Epoch: 26, Loss: 0.2437
Epoch: 27, Loss: 0.2693
Epoch: 28, Loss: 0.2393
Epoch: 29, Loss: 0.2305
Evaluating: 100%|█████████████████████████████████████████████████████████████| 465930/465930 [17:38<00:00, 440.00it/s]
Epoch: 30, Loss: 0.2307, Train: 0.9721, Val: 0.9541, test: 0.9522
四、作业
- 尝试将数据集切分成不同数量的簇进行实验,然后观察结果并进行比较。
将num_parts=1500
改成num_parts=1000
进行实验,实验结果如下:
Epoch: 01, Loss: 1.3562
Epoch: 02, Loss: 0.5103
......
Epoch: 29, Loss: 0.2154
Evaluating: 100%|██████████| 465930/465930 [17:43<00:00, 441.52it/s]
Epoch: 30, Loss: 0.2126, Train: 0.9741, Val: 0.9542, test: 0.9530
可以观察到分成1000簇和分成1500簇的结果相差不大。