文章目录
0、基本介绍
- 作者:Zemin Liu,Xingtong Yu, Yuan Fang,Xinming Zhang
- 会议:2023-WWW
- 文章链接:GraphPrompt:Unifying pre-training and downstream tasks for graph neural networks.
- 代码链接:GraphPrompt:Unifying pre-training and downstream tasks for graph neural networks.
What?Why?How?
1、研究动机
图神经网络(GNNs)尽管是学习图表征的一个有效的方法,但是它的性能很大程度上依赖于大规模的指定任务的监督训练,然而,监督训练的数据的标签是匮乏的。为了克服标签稀缺的问题,提出了“pretrain,fine-tune”的学习范式,希望预训练好一个模型,在下游少量标签的数据上进行微调,就能取得较好性能。
但是,由于图上三种不同级别的任务(node-,edge- and graph-level)训练目标的差距是很大的(预训练和下游任务之间的目标不一致的问题),一个链接预测任务不能很好的适应下游结点分类任务,这就导致微调的性能不是最佳的。
为了弥补预训练模型和下游任务之间的差距,提出了“pretrain,prompt,fine-tune”的范式(prompt tuning 也是微调的一种),希望通过prompt的方法,将下游数据来适应预训练模型,进而取得更好的训练效果。
本篇基于“pretrain,prompt,fine-tune”的范式,提出GraphPrompt少样本的提示框架,统一下游不同的任务目标,将下游任务目标和预训练模型目标对齐。
2、创新点
(1)统一上下游训练目标
上游训练是链接预测,两个结点的邻域表征的越相似,那么,这两个结点相连接的概率越大。
下游结点分类和图分类也统一为链接预测,预测结点/图的邻域表征与提示结点/图的表征计算余弦相似度,取相似度最大的对应的类别作为预测标签。
使用一个共同的训练目标统一上下游任务可以很好的做到知识的迁移,同时上下游任务也更加兼容。
(2)区分不同下游任务
通过区分不同下游任务能够捕获任务差别并实现在指定任务上任是最优的。
作者基于直觉上的假设,不同的任务受益于不同的聚合方案,节点分类更加关注与目标结点在特征上更相关的特征。图分类倾向于与图类别相关的特征。作者提出prompt-assisted READOUT操作,用于区分不同的下游任务。
自我感觉创新之处并不是很大,很多都是基于直觉上的假设,要是有合理的证明就更好了(手动滑稽);一个可以学习的一点是提示方法,之前的基于提示文章例如,GPF和All-in-One,一个是将提示向量加在特征中,另一个是将GPF的提示向量改为了提示图,并统一下游任务。而本篇是统一上下游任务,并将提示加在了READOUT阶段,用于区分不同任务。直觉上的想法是可以认可的,但背后的原因还需要进一步讨论。(仅个人观点,欢迎讨论)
3、Method
3.1、Unification Framework
统一预训练和下游任务的关键在于为不同的任务找到一个共同的模板,然后将特定于任务提示与每个下游任务的模板进一步融合,来区分不同任务的不同特征。
与NLP和CV领域不同,图学习的独特之处在于利用图的拓扑结构——子图可以作为表达结点和图的一种通用结构。
图 G = ( V , E ) G=(V,E) G=(V,E)中任意一个结点v的局部邻域表示为 S v = ( V ( S v ) , E ( S v ) ) S_v=(V(S_v),E(S_v)) Sv=(V(Sv),E(Sv)),这个局部邻域可以是k-hop局部邻域,子图 S v S_v Sv不仅包含结点 v v v的信息,还包含丰富的上下文信息。图级任务,图的最大子图即为图本身(i.e. S G = G S_G= G SG=G),包含了所有的信息。所以,子图可以用于表示结点级和图级实例:给定结点/图实例 x x x,子图 S x S_x Sx提供了与实例 x x x相关的可以利用的信息。
统一的任务模板。基于上述图级和结点级子图的定义,可以将结点级和图级任务统一为一个统一的模板。也就是说,预训练的链接预测任务以及下游的结点和图分类任务可以重新定义为子图的相似性学习。
s
x
s_x
sx为子图
S
x
S_x
Sx的向量表征,
sim(.,.)
\text{sim(.,.)}
sim(.,.)为余弦相似度函数。
3.1.1、link predication
给定图
G
=
(
V
,
E
)
G=(V,E)
G=(V,E)和一个三元组
(
v
,
a
,
b
)
(v,a,b)
(v,a,b),其中
(
v
,
a
)
∈
E
and
(
v
,
b
)
∉
E
(v,a) \in E \;\text{and}\; (v,b) \notin E
(v,a)∈Eand(v,b)∈/E,若
sim
(
s
v
,
s
a
)
>
sim
(
s
v
,
s
b
)
\text{sim}(s_v,s_a) > \text{sim}(s_v,s_b)
sim(sv,sa)>sim(sv,sb)
直觉上,与另一个未连接的节点对相比,相互连接的结点对之间的子图表征更相似。
3.1.2、node classification
对于一个拥有 C C C个类别的图 G = ( V , E ) G=(V,E) G=(V,E),被标记结点集合 D = { ( v 1 , l 1 ) , ( v 2 , l 2 ) , … } D=\{(v_1,\mathcal{l}_1),(v_2,\mathcal{l}_2),\dots\} D={(v1,l1),(v2,l2),…},其中, v i ∈ V v_i \in V vi∈V, l i \mathcal{l}_i li是结点 v i v_i vi的标签,这里采用k-shot prompt,所以这里有k个 ( v i , l i = c ) ∈ D , c ∈ C (v_i,\mathcal{l}_i=c)\in D,c\in C (vi,li=c)∈D,c∈C。对于每个类 c ∈ C c\in C c∈C,一个结点类原型子图(node class prototypical subgraph)表示为向量 s ~ c \tilde{s}_c s~c
s ~ c = 1 2 ∑ ( v i , l i ) ∈ D , l i = c ( s v i ) \tilde{s}_c=\frac{1}{2}\sum_{(v_i,\mathcal{l}_i)\in D,\mathcal{l}_i=c}(s_{v_i}) s~c=21(vi,li)∈D,li=c∑(svi)
类原型子图是一个“虚拟”子图,它想要表达与节点上下文子图在相同的嵌入空间中。它定义为给定类中被标记节点的上下文子图的平均表征。所以,给定一个未被标记的结点
v
j
v_j
vj,他的标签
l
j
\mathcal{l}_j
lj为
l
j
=
arg
max
c
∈
C
sin
(
s
v
j
,
s
~
c
)
\mathcal{l}_{j}=\arg\max_{c\in C}\sin(s_{v_{j}},\tilde{s}_{c})
lj=argc∈Cmaxsin(svj,s~c)
直觉上,一个节点应该属于其原型子图与该节点的上下文子图最相似的类。
3.1.3、graph classification
对于一个拥有 C C C个类别的图集合 G \mathcal{G} G,被标记图集合 D = { ( G 1 , L 1 ) , ( G 2 , L 3 ) , … } \mathcal{D}=\{(G_1,L_1),(G_2,L_3),\dots\} D={(G1,L1),(G2,L3),…}, G i ∈ G G_i\in\mathcal{G} Gi∈G, L i L_i Li是与之对应的标签,因为这是k-shot prompt ,所以被标记集合 D \mathcal{D} D由k个元素。对于每个类别 c ∈ C c\in \mathcal{C} c∈C,定义一个graph class prototypical subgraph,也由子图的平均嵌入向量表示:
s
~
c
=
1
k
∑
(
G
i
,
L
i
)
∈
D
,
L
i
=
c
s
G
i
\tilde{\mathbf{s}}_{c}=\frac{1}{k}\sum_{(G_{i},L_{i})\in\mathcal{D},L_{i}=c}s_{G_{i}}
s~c=k1(Gi,Li)∈D,Li=c∑sGi
然后,给定一个不在标记集合
D
\mathcal{D}
D中的图
G
j
\mathcal{G}_j
Gj,它的类标记应该是
L
j
L_j
Lj
L
j
=
arg
max
c
∈
C
sin
(
s
G
j
,
s
~
c
)
L_{j}=\arg\max_{c\in C}\sin(s_{G_{j}},\tilde{s}_{c})
Lj=argc∈Cmaxsin(sGj,s~c)
直觉上,一个图应该属于其原型子图与它自己最相似的类。
3.1.4、unified representation
值得注意的是,节点和图分类可以以统一的形式表示。假设
(
x
,
y
)
(x,y)
(x,y)为图数据实例,
x
x
x要么是结点要么是图,
y
∈
Y
y\in Y
y∈Y是
x
x
x的类标签。那么,
y
=
arg
max
c
∈
Y
sin
(
s
x
,
s
~
c
)
y=\arg\max_{c\in Y}\sin(s_{x},\tilde{s}_{c})
y=argc∈Ymaxsin(sx,s~c)
最后,对于一个通过GNN生成的结点表征
h
v
h_v
hv,常见的计算
s
x
s_x
sx的方法是采用READOUT操作,聚合子图
S
x
S_x
Sx中节点的表示:
s
x
=
R
E
A
D
O
U
T
(
{
h
v
:
v
∈
V
(
S
x
)
}
)
\mathbf{s}_{x}=\mathrm{READOUT}(\{\mathbf{h}_{v}:v\in V(S_{x})\})
sx=READOUT({hv:v∈V(Sx)})
聚合操作有不同的方法,这里作者简单的采用sum pooling。
3.2、pretrain phase
预训练阶段使用连接预测任务,连接关系存在于大多数图中,因而很容易获得,无需额外的标注信息,这样就可以以自监督的方式在无标签图上优化链接预测的目标。
作者构建损失函数的直觉前面我们已经提及:链接预测任务两个候选节点的上下文子图的相似性。通常,两个正例的子图(即,相连的)应当比那些负例的子图(即,无连接)更相似。同时,子图相似性的预先训练的先验知识可以自然地转移到下游的节点分类,类似的直觉:同一类中的节点的子图应该比来自不同类的节点的子图更相似。
这种直觉也可以支持图分类。
给定图
G
G
G上的结点
v
v
v,随机从
v
v
v的邻域中采用一个结点
a
a
a,从与结点
v
v
v不相连的结点集合中采样结点
b
b
b。目标是增加上下文子图
S
v
S_v
Sv和
S
a
S_a
Sa之间的相似性,而减少
S
v
S_v
Sv和
S
b
S_b
Sb之间的相似性。
更一般地,在一组无标签图
G
G
G上,我们从每个图中采样多个三元组以构建整体训练集
T
p
r
e
\mathcal{T}_{pre}
Tpre。然后,定义预训练损失:
L
p
r
e
(
Θ
)
=
−
∑
(
v
,
a
,
b
)
∈
T
p
r
e
ln
exp
(
sin
(
s
v
,
s
a
)
/
τ
)
∑
u
∈
{
a
,
b
}
exp
(
sin
(
s
v
,
s
u
)
/
τ
)
\mathcal{L}_{\mathrm{pre}}(\Theta)=-\sum_{(v,a,b)\in\mathcal{T}_{\mathrm{pre}}}\ln\frac{\exp(\sin(s_v,s_a)/\tau)}{\sum_{u\in\{a,b\}}\exp(\sin(s_v,s_u)/\tau)}
Lpre(Θ)=−(v,a,b)∈Tpre∑ln∑u∈{a,b}exp(sin(sv,su)/τ)exp(sin(sv,sa)/τ)
τ
\tau
τ是温度参数,
Θ
\Theta
Θ是GNN模型的权重参数。预训练阶段的输出是最优模型参数
Θ
\Theta
Θ可用于初始化下游任务的GNN权重,从而实现先验知识向下游的转移。
3.3、prompting for downsteam tasks
尽管上游任务和下游任务统一可以实现知识的有效迁移,但是区分不同的任务仍然是很重要的,这样能够捕获任务差别并实现在指定任务上任是最优的。
作者认为,不同的任务受益于不同的聚合方案,节点分类更加关注与目标结点在特征上更相关的特征。图分类倾向于与图类别相关的特征。作者提出prompt-assisted READOUT操作。
具体来说,
p
t
\text{p}_t
pt定义为用于下游任务
t
t
t的一个可以学习的提示向量,对于任务
t
t
t,在子图
S
x
S_x
Sx上的prompt-assisted READOUT操作为:
s
t
,
x
=
R
E
A
D
O
U
T
(
{
p
t
⊙
h
v
:
v
∈
V
(
S
x
)
}
)
\mathbf{s}_{t,x}=\mathrm{READOUT}(\{\mathbf{p}_{t}\odot\mathbf{h}_{v}:v\in V(S_{x})\})
st,x=READOUT({pt⊙hv:v∈V(Sx)})
s t , x s_{t,x} st,x是任务 t t t子图的表征, ⊙ \odot ⊙定义为element-wise multiplication,也就是说,我们对来自子图的节点表示进行特征加权求和,其中提示向量 p t \text{p}_t pt是维度上的重新加权,以便提取任务的最相关的先验知识。
提示微调。为了优化可学习提示,也被称为提示调整,我们制定了基于子图相似性的公共模板的损失,使用辅助的特定于任务的子图表示。
具体来说,具有标记集合
T
t
=
{
(
x
1
,
y
1
)
,
(
x
2
,
y
2
)
,
…
}
\mathcal{T}_t=\{(x_1,y_1),(x_2,y_2),\dots\}
Tt={(x1,y1),(x2,y2),…}的任务
t
t
t,
x
i
x_i
xi是一个实例(一个结点或一个图),
y
i
∈
Y
y_i\in Y
yi∈Y是
x
i
x_i
xi的类标签。提示微调的损失函数定义为
L
p
r
o
m
p
t
(
p
t
)
=
−
∑
(
x
i
,
y
i
)
∈
T
t
ln
exp
(
sin
(
s
t
,
x
i
,
s
~
t
,
y
i
)
/
τ
)
∑
c
∈
Y
exp
(
sin
(
s
t
,
x
i
,
s
~
t
,
c
)
/
τ
)
,
\mathcal{L}_{\mathrm{prompt}}(\mathbf{p}_{t})=-\sum_{(x_{i},y_{i})\in\mathcal{T}_{t}}\ln\frac{\exp(\sin(\mathbf{s}_{t,x_{i}},\tilde{\mathbf{s}}_{t,y_{i}})/\tau)}{\sum_{c\in Y}\exp(\sin(\mathbf{s}_{t,x_{i}},\tilde{\mathbf{s}}_{t,c})/\tau)},
Lprompt(pt)=−(xi,yi)∈Tt∑ln∑c∈Yexp(sin(st,xi,s~t,c)/τ)exp(sin(st,xi,s~t,yi)/τ),
其中,对于类 c c c的类原型子图被表示为 s ~ t , c \tilde{\mathbf{s}}_{t,c} s~t,c。注意,提示微调仅优化可学习提示向量 p t p_t pt参数化,而没有GNN权重。预先训练的GNN权重 Θ \Theta Θ被冻结用于下游任务不需要微调。
模型整体框架如下图所示。