FLAG
Kong K, Li G, Ding M, et al. Flag: Adversarial data augmentation for graph neural networks[J]. arXiv preprint arXiv:2010.09891, 2020.
本文主要是提出了GNN的数据增强方法——FLAG,在增强模型鲁棒性的同时,研究其对模型准确率的影响。
我将FLAG方法也运用到了我自己设计的模型当中,在减少标准差的同时也让模型的准确率有了小幅度的提升。此外,FLAG方法使用起来灵活方便,以后的应用应该会十分的广泛。
问题的提出
OGB数据集是一个真实且大型的图数据集,它的出现给GNN带来了不小的挑战,很多模型都会在OGB数据集上出现overfit的情况,也就导致了acc一直上不去。
每当GNN发生一些问题,我们总是会去借鉴CNN中的一些策略。那么,CV中如何去缓解overfit呢?——就是采样数据增强(data augmentation)的方法。自然而然的就会想到,我们可不可以把数据增强推广到GNN上。
当然,很多人都针对GNN的数据增强提出了自己的观点。以往的观点关注的是图结构,一般是采用DropoutEdge的方法去进行正则化,以达到数据增强的效果。但是,也暴露出了灵活性、通用性、易用性、有效性等弊端。而本文主要是关注针对节点特征空间的数据增强方法。
提到节点特征空间,其他领域内很多都是采用对抗性数据增强的方法。它是通过对抗性扰动来增强数据并最终缓解overfit的。虽然它能够增强模型的鲁棒性,但是往往是以牺牲acc为代价的,如何在这种情况下提高acc,成为了一个需要研究的问题。
FLAG方法
FLAG(Free Large-scale Adversarial Augmentation on Graphs),是图上免费的大型对抗性数据增强方法,以缓解overfit。它主要是通过为输入节点特征增加基于梯度的对抗性扰动,来实现数据增强的。
FLAG方法的算法流程
算法思路详解:
其实弄明白之后发现其实思路和实现还是比较简单的。原来我们用SGD或Adam优化器来训练时,每个epoch就一次梯度下降、前向和反向传播,即N个epoch,每个epoch有1个SGD。现在FLAG方法的情况是,为了和之前的runtime大致相当,增加每个epoch中的PGD(投影梯度下降)次数,减少epoch,即N/M个epoch,每个epoch有M次PGD(当然也可以不减少epoch)。
在每个epoch中,我们首先定义一个和X形状相同且服从(-alpha, alpha)均匀分布的扰动矩阵pert,它带有梯度,然后作为扰动和X一起送入模型中训练M个step。在每个step中,我们根据pert的grad手动更新pert,然后清零pert的grad;每次都会得到loss所在的计算图的grad的1/M,不清零,将M次梯度进行累加。在M次step做完之后,统一对loss所在的计算图进行1次反向传播并更新参数。
更新pert的方法借鉴了对抗性训练的PGD方法,大家可以自己去阅读paper来感受。
FLAG方法的pytorch代码