主要问题: 现有 GNN 模型在下游任务中性能不佳,尤其是在标注数据稀缺的情况下。这主要归因于两个问题:
- 依赖大量标注数据: 现有 GNN 模型通常需要大量标注数据进行训练,这在实际应用中难以获得。
- 预训练和下游任务目标不一致: 现有的预训练方法往往只关注学习图结构特征,而下游任务的目标则更具体,例如节点分类或图分类。这种目标不一致会导致预训练知识无法有效迁移到下游任务。
GraphPrompt 的核心思想:
GraphPrompt 通过以下两个关键思想来解决上述问题:
- 统一任务模板: 将预训练和下游任务都统一到“子图相似度学习”的模板上。无论是链接预测、节点分类还是图分类,都可以通过计算子图之间的相似度来实现。
- 可学习的任务提示: 针对不同的下游任务,设计可学习的提示向量来指导子图表示学习过程中的 ReadOut 操作,从而实现任务特定的知识迁移。
GraphPrompt 的实现细节
子图相似度学习:
- 子图定义: 子图可以是节点周围的局部子图,也可以是整个图。
- 子图表示: 通过 GNN 模型将子图中的节点表示聚合为一个子图表示。GraphPrompt 可以使用不同的聚合方案,例如求和池化或注意力机制。
- 相似度计算: 使用余弦相似度函数计算子图表示之间的相似度。
- 任务目标: 将不同的任务目标转化为子图相似度学习的目标,例如:
- 链接预测: 计算两个节点的子图表示之间的相似度,预测它们之间是否存在链接。
- 节点分类: 计算一个节点的子图表示与各个类别原型子图表示之间的相似度,预测该节点的类别。
- 图分类: 计算一个图的子图表示与各个类别原型子图表示之间的相似度,预测该图的类别。
可学习的任务提示:
- 提示向量: 每个下游任务都对应一个可学习的提示向量,该向量用于指导子图表示学习过程中的 ReadOut 操作。
- 特征加权: 提示向量对子图中的节点表示进行特征加权,突出与任务相关的特征,抑制与任务无关的特征。
- 任务特定性: 不同的任务需要不同的提示向量,从而实现任务特定的知识迁移。
- 提示微调: 在下游任务中,GraphPrompt 使用提示微调来优化提示向量,而不需要微调预训练的 GNN 模型。
GraphPrompt 的优势:
- 减少标注数据需求: 通过预训练和提示,GraphPrompt 可以在标注数据稀缺的情况下取得良好的性能,例如少样本学习任务。
- 提升模型泛化能力: 统一的任务模板和可学习的提示使得 GraphPrompt 能够适应不同的下游任务,从而提升模型的泛化能力。
- 参数效率高: 与传统的微调方法相比,GraphPrompt 的下游任务只需要更新提示向量,从而大大降低了参数数量和计算量。
实验结果:
- 在五个公开数据集上,GraphPrompt 在节点分类和图分类任务中都取得了显著的性能提升,尤其是在少样本学习场景下。
- GraphPrompt 明显优于其他基线模型,包括端到端的 GNN 模型、图预训练模型和图提示模型。
- GraphPrompt 的参数数量和计算量都低于其他基线模型,从而提高了模型的效率。
未来研究方向:
- 探索更复杂的提示设计,例如注意力机制或神经网络。
- 研究 GraphPrompt 在其他图学习任务中的应用,例如图生成和图推理。
- 将 GraphPrompt 与其他图学习技术相结合,例如图增强和图注意力网络。