©PaperWeekly 原创 · 作者|桑运鑫
学校|上海交通大学硕士生
研究方向|图神经网络应用
最近对一个大规模的图训练嵌入,发现相关的中文资料还是很欠缺的,把自己踩的一些坑记下来。本文主要针对 DGL [1] 和 PyTorch [2] 两个框架。
训练大规模图
对于大规模图不能像小图一样把整张图扔进去训练,需要对大图进行采样,即通过Neighborhood Sampling
方法每次采样一部分输出节点,然后把更新它们所需的所有节点作为输入节点,通过这样的方式做 mini-batch 迭代训练。具体的用法可以参考官方文档中的 Chapter 6: Stochastic Training on Large Graphs [3]。
但是 GATNE-T [4] 中有一种更有趣的做法,即只把 DGL 作为一个辅助计算流的工具,提供Neighborhood Sampling
和Message Passing
等过程,把Node Embedding
和Edge Embedding
等存储在图之外,做一个单独的Embedding
矩阵。每次从 dgl 中获取节点的id
之后再去Embedding
矩阵中去取对应的 embedding 进行优化,以此可以更方便的做一些优化。
缩小图规模
从图的Message Passing
过程可以看出,基本上所有的图神经网络的计算都只能传播连通图的信息,所以可以先用 connected_componets [5] 检查一下自己的图是否是连通图。如果分为多个连通子图的话,可以分别进行训练或者选择一个大小适中的 component 训练。
如果图还是很大的话,也可以对图整体做一次Neighborhood Sampling
采样一个子图进行训练。
减小内存占用
对于大规模数据而言,如何在内存中存下它也是一件让人伤脑筋的事情。这时候采用什么样的数据结构存储就很关键了。首先是不要使用原生的list
,使用np.ndarray
或者torch.tensor
。尤其注意不要显式的使用set
存储大规模数据(可以使用set
去重,但不要存储它)。
注意:四种数据结构消耗的内存之间的差别(比例关系)会随着数据规模变大而变大。
其次就是在PyTorch
中,设置DataLoader
的num_workers