笔记整理:刘治强,浙江大学硕士生
链接:https://dl.acm.org/doi/10.1145/3543873.3587596
1. 动机
尽管GNN在链路预测任务中表现出较高的准确性,但它并不是设计用于归纳设置下的链路预测。此外,由于对图数据的依赖性,GNN在大规模工业部署中显示出相当高的推理延迟。
尽管由于缺乏对图拓扑的访问,MLP在节点分类任务上的归纳偏差比GNN小得多,但由于其低推理延迟,它们在工业规模应用程序中获得了很高的欢迎。先前的观察结果使得最近的研究人员在节点分类任务中利用从教师GNN到学生MLP的跨模型知识蒸馏。
鉴于在链接预测,特别是归纳链接预测方面的知识蒸馏仍是一个未被探索的领域,本文致力于在转导链接预测和归纳链接预测任务中加速推理。此外,由于链路预测可能涉及对源节点和多个上下文节点类型之间的多种关系进行推理。
2. 方法
提出的框架:Graph2Feat
知识蒸馏:将知识从繁重的教师模型提炼为更轻量级的学生模型。
本文从教师GNN模型生成软预测 ,然后通过最小化损失函数 训练学生模型MLP去匹配对应的软预测 。最终得到的学生模型的损失函数为: