论文地址:Bag of Tricks for Node Classification with Graph Neural Networks
一.概述
本文作者总结了前人关于图上半监督节点分类任务的常用Tricks,另外还提出了将节点特征和节点标签组合来进行训练和更鲁棒的损失函数,结果表明作者的设计是有效的,可供参考。
二.背景
首先给出论文的符号表:
符号 | 说明 |
---|---|
G = ( V , E ) G=(V,E) G=(V,E) | 图, V = { v 1 , . . . , v N } V=\{v_1,...,v_N\} V={v1,...,vN}为顶点集, E E E为边集 |
A \bold{A} A | 邻接矩阵 |
D D D | 度矩阵(对角阵) |
X = ( x 1 , . . . , x N ) T \bold{X}=(x_1,...,x_N)^T X=(x1,...,xN)T | 节点特征 |
Y = ( y 1 , . . . , y N ) T ∈ R N × C \bold{Y}=(y_1,...,y_N)^T \in \mathbb{R}^{N \times C} Y=(y1,...,yN)T∈RN×C | 标签矩阵(one-hot), C C C为类别数 |
S = D − 1 / 2 A D − 1 / 2 \bold{S} = D^{-1/2}\bold{A}D^{-1/2} S=D−1/2AD−1/2 | 正则化后的邻接矩阵(对称) |
作者以图中前 M M M个节点作为训练集。
作者将数据集分为训练集和测试集。
标签传播算法
标签传播算法(Label Propagation Algorithm, LPA)的动机是相邻的节点可能具有相似的标签。LPA通过迭代计算如下公式:
Y
(
k
+
1
)
=
λ
S
Y
(
k
)
+
(
1
−
λ
)
Y
(
0
)
Y^{(k+1)}=\lambda S Y^{(k)}+(1-\lambda) Y^{(0)}
Y(k+1)=λSY(k)+(1−λ)Y(0)
来求解线性系统
Y
∗
=
(
1
−
λ
)
(
I
−
λ
S
)
−
1
Y
Y^{*}=(1-\lambda)(I- \lambda S)^{-1} Y
Y∗=(1−λ)(I−λS)−1Y 。其中
Y
(
0
)
Y^{(0)}
Y(0)是标签矩阵,训练集中节点的标签保持不变,但测试节点的标签全部填为0。从LPA的传播公式可以看出,LPA算法没能够利用节点的特征。
图神经网络
GNNs是深度学习在图上学习的范式,其中最有名的模型之一便是GCN,其传播规则如下:
X
(
l
+
1
)
=
σ
(
D
−
1
2
A
D
−
1
2
X
(
l
)
W
(
l
)
)
\boldsymbol{X}^{(l+1)}=\sigma\left(D^{-\frac{1}{2}} A D^{-\frac{1}{2}} X^{(l)} W^{(l)}\right)
X(l+1)=σ(D−21AD−21X(l)W(l))
在GCN之后,各种图卷积神经网络开始涌现。与LPA算法不同的是,在推断测试节点的标签时,GNN并不显式利用训练节点的真实标签。
标签和特征传播的组合
既然单独使用标签传播或单独使用节点特征都能取得良好的性能,那将二者进行组合是有意义的。之前也有工作在这方面进行了探索,典型的工作便是C&S,但该方法不是端到端的。
三.半监督节点分类的Tricks
3.1 已经存在的Tricks
采样(Sampling):采样技术能使得GNN往大图上扩展。
数据增强(Data Augmentation):数据增强可以用来缓解过拟合和过平滑。
重正则化(Renormalization):GCN提出来的trick,即为图节点添加自环。
残差连接(Residual Connections):在GCN的传播规则中添加线性连接,即:
X
(
l
+
1
)
=
σ
(
D
~
−
1
2
A
~
D
~
−
1
2
X
(
l
)
W
0
(
l
)
+
X
(
l
)
W
1
(
l
)
)
X^{(l+1)}=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} X^{(l)} W_{0}^{(l)}+X^{(l)} W_{1}^{(l)}\right)
X(l+1)=σ(D~−21A~D~−21X(l)W0(l)+X(l)W1(l))
在这种形势下,线性分量保留了激活经过众多传播层也能区分的节点表示,使得GCN的表达能力更强并能克服过平滑问题。
3.2 新的Tricks
3.2.1 标签的使用
标签作为输入(Label as Input):作者提出了一种新的采样技术,允许GNN模型将标签信息作为输入来学习标签之间的相互关系,下图展示的是该算法训练过程的伪代码:
从伪代码中可以看出,作者将训练集拆分为了两部分 D train L \mathcal{D}_{\text {train }}^{L} Dtrain L和 D train U \mathcal{D}_{\text {train }}^{U} Dtrain U。然后作者将 D train U \mathcal{D}_{\text {train }}^{U} Dtrain U的标签设置为0(第3行),并预测其标签。具体来说就是在训练过程中 D train L \mathcal{D}_{\text {train }}^{L} Dtrain L的输入包含特征和标签,而 D train U \mathcal{D}_{\text {train }}^{U} Dtrain U的输入仅包含特征,然后以此来预测 D train U \mathcal{D}_{\text {train }}^{U} Dtrain U的标签,并计算loss然后通过反向传播学习GNN的参数。在最终的推断中,训练集的所有标签都将作为输入。
标签重用增强(Augmentation with Label Reuse):作者进一步提出了标签重用,具体做法是将先前迭代的预测软标签来作为输入,这种情况下 D train U \mathcal{D}_{\text {train }}^{U} Dtrain U中的节点的标签将不再是零值向量,而是上一次迭代的预测结果。(对应的是算法1中的5-8行)
3.2.2 用于分类的鲁棒损失函数
对于2分类任务,常见的loss函数为logistic loss。然而,logistic loss对异常值比较敏感,而非凸损失函数可能更稳健。为此,作者考虑削弱凸性条件,从而设计准凸损失以提高鲁棒性:
ϕ
ρ
−
logit
(
v
)
:
=
ρ
(
ϕ
logit
(
v
)
)
\phi_{\rho-\operatorname{logit}}(v):=\rho\left(\phi_{\operatorname{logit}}(v)\right)
ϕρ−logit(v):=ρ(ϕlogit(v))
其中
ρ
:
R
+
→
R
+
\rho: \mathbb{R}^{+} \rightarrow \mathbb{R}^{+}
ρ:R+→R+是非单减函数,表1总结了
ρ
(
⋅
)
\rho(\cdot)
ρ(⋅) 的设计。
Loge Loss是作者提出的,其中
ϵ
\epsilon
ϵ 是可调参数,在论文中固定为
1
−
log
2
1-\log 2
1−log2。Loge Loss也可以扩展到多分类任务,其对应的数学形式为:
ℓ
log
e
(
y
^
,
y
)
=
log
(
ϵ
−
log
exp
(
y
^
class
)
∑
i
=
1
C
exp
(
y
^
i
)
)
−
log
ϵ
\ell_{\log e}(\hat{\boldsymbol{y}}, \boldsymbol{y})=\log \left(\epsilon-\log \frac{\exp \left(\hat{y}_{\text {class }}\right)}{\sum_{i=1}^{C} \exp \left(\hat{y}_{i}\right)}\right)-\log \epsilon
ℓloge(y^,y)=log(ϵ−log∑i=1Cexp(y^i)exp(y^class ))−logϵ
其中
y
^
\hat{\boldsymbol{y}}
y^和
y
\boldsymbol{y}
y都是one-hot向量。
3.3 Trick的应用——GAT架构的调整
作者在GAT架构中加入了Residual Connections和Renormalization,得到如下形式:
X
(
l
+
1
)
=
σ
(
D
~
−
1
2
A
~
a
t
t
D
~
−
1
2
X
(
l
)
W
0
(
l
)
+
X
(
l
)
W
1
(
l
)
)
X^{(l+1)}=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A}_{a t t} \tilde{D}^{-\frac{1}{2}} X^{(l)} W_{0}^{(l)}+X^{(l)} W_{1}^{(l)}\right)
X(l+1)=σ(D~−21A~attD~−21X(l)W0(l)+X(l)W1(l))
其中
A
a
t
t
=
D
α
A_{a t t}=D \alpha
Aatt=Dα,
α
i
j
(
l
)
=
exp
(
LeakyReLU
(
a
T
[
W
(
l
)
x
i
(
l
)
∥
W
(
l
)
x
j
(
l
)
]
)
)
∑
r
∈
N
(
v
i
)
exp
(
LeakyReLU
(
a
T
[
W
(
l
)
x
i
(
l
)
∥
W
(
l
)
x
r
(
l
)
]
)
)
\alpha_{i j}^{(l)}=\frac{\exp \left(\operatorname{LeakyReLU}\left(\boldsymbol{a}^{T}\left[W^{(l)} x_{i}^{(l)} \| W^{(l)} x_{j}^{(l)}\right]\right)\right)}{\sum_{r \in \mathcal{N}\left(v_{i}\right)} \exp \left(\operatorname{LeakyReLU}\left(\boldsymbol{a}^{T}\left[W^{(l)} x_{i}^{(l)} \| W^{(l)} x_{r}^{(l)}\right]\right)\right)}
αij(l)=∑r∈N(vi)exp(LeakyReLU(aT[W(l)xi(l)∥W(l)xr(l)]))exp(LeakyReLU(aT[W(l)xi(l)∥W(l)xj(l)]))。
作者还提出了其它变体,具体详见论文。
四.实验
数据集:Cora、Citeseer、Pubmed、ogbn-arxiv、ogbn-porteins、 ogbn-products、Reddit。数据集的特征如下所示:
表3报告了使用label as input和label reuse的性能,从结果可以看到作者设计的有效性。另外,将作者的方法与C&S结合可以进一步提升性能。
表4对比了不同损失函数的性能,实验中每个模型使用相同的超参数进行训练,仅改变损失函数。从结果可以看出作者设计的Loge Loss在大多数数据集上都表现良好。
五.结语
在本文的最后给大家安利一个开箱即用的GNN Tricks库:gtrick,里面集成了GNN的多种Tricks,还有使用示例,其安装命令为:
pip install gtrick
欢迎感兴趣的小伙伴自行探索。