What can neural networks reason about?ICLR2020高分论文解析
论文地址:论文地址
最近在看这篇论文作者MIT研究生keyulu Xu的2021ICLR最高分论文时,发现很多地方与此片论文有联系,所以先回头来看看他2020年的论文。
论文干了什么
此篇论文正式定义此算法比对,并得出样本复杂度界限,并且该界限随着更好的比对而降低。并且作者展示GNNs和DP(动态规划)问题的对比,发现GNNs很有希望取解决DP等问题,在几个推理任务上,(Summary statistics, Relational argmax, Dynamic programming, NP-hard problem)提出的理论得到了经验结果的支持。
其中GNN在和DP(动态规划)的问题的算法有很高的相似性,因此能够有效采样地学习这些任务(有效采样:因为我们机器学习时从一个支持及上面采样一个训练集来做训练,所以采样的样本越少效果越好就说明这个模型good,这就是后面的PAC学习大概要说的东西)
上图将GNN的计算结构和Bellman-Ford算法对比,可以看出GNN中只需要MLP去学习一个简单的推理步骤,两个计算结构就基本相似了,那如果用MLP直接来模拟Bellman-Ford算法的话,还要去学习一个for循环,通过论文中的定理3.5也说明了for循环是一个非常难学习的算法步骤。
所以,与不正确对齐的神经网络相比,能够更好地与正确的推理过程(算法解决方案)保持一致的神经网络可以更轻松地学习推理任务。换句话就是说,针对不同的任务应该选择和此任务解决算法保持一致的神经网络框架何以使神经网络更好的学习此任务。
PAC learning and sample complexity
计学习理论中最基本的是概率近似正确(Probably Approximately Correct)简称PAC。
这里我对sample complexity理解为:最少有多少的训练样本能学习到目标概念。
所以我们之间看到定义3.3
具体的就是PAC学习的定义,包括PAC Identify(PAC辨识)和PAC Learnable(PAC 可学习),那我们可以看到||f(x) - g(x)||为误差,那就是误差小于ε的概率大于等于σ(误差小于ε的概率是比较大的),那么就说明这个算法是PAC可学习的。
样本复杂度是满足上式最小的M(M是样本个数),在学习算法A之下,function g就是(M, ε, σ)可学习的。
基于PAC定理,作者定义了算法比对的数字度量(这里algorithmic alignment 我就统一叫成算法比对了呀。)
最后的这个M就是最大的一个模块的采样复杂度*n,就是这个模型的Algorithmic alignment Value.
一个好的算法,应该是有一个小的M,也就是说所有算法步骤fi,去模拟 的算法步骤g都是非常容易去学习的。
定理3.6说明了样本复杂度界限随算法对齐值M的增加而增加,这也不难理解, M越小说明对齐(匹配)的越好。
定理3.6我下来再去研究一下,不是很懂 ,提出了三个假设,分别是算法稳定性,序列学习(对前面较好的学习效果是有记忆保存的),Lipschizness约束。
预测神经网络的推理能力
上图是对4个复杂任务的测试集准确率。(a)Summary statistics,除了MLP其他模型都泛化了,(b)Relaational argmax, Deep Set失败。©Dynamic programming 只有GNN在有足够的循环次数下泛化了。(d)NP-hard problem,GNNs失败,但是NES泛化了。
SO, GNN > MLP
其实我对这个实验设置挺感兴趣的,如何用GNN去做DP问题的实现,看了作者的github代码和附录,这里的GNN应该就是fig2中的那样,是一个有点怪的GNNs,不过这样写也没得毛病,就是for循环加MLP,MLP学习的是每个node和邻居节点的信息。不过GNN是需要构建graph使用邻居信息的,所以上面4个复杂任务都有位置信息,考虑了距离远近,我觉得应该是一张完全图,不过每条边都有距离信息。
然后,就是对4个复杂任务的具体描述了。。。
这张图表是各个模型在DP任务下,随着训练集采样个数增加各模型测试集准确率的变化。
如果一个神经网络模型和算法解法很匹配的话,那么该神经网络的测试集准确率就会非常快速的上升,比如训练集采样个数在40000 ~ 80000的阶段GNN4提升了23%, DeepSet提升了0.2%,那就说明GNN比DeepSet更加匹配DP问题的算法框架!
总结
- 这篇论文是正式了解神经网络如何学习推理的第一步,提出了算法框架对比的概念,使神经网络的设计可以学习一些推理范例。比如DP,并且用实验证明了其分析的确实是正确的。
- 对于一个神经网络来说学习for训练时相当难的。
- 通过实验我们可以看出GNN的推理能力是很强的,因为GNN的框架概括了许多推理任务。
最后不得不佩服作者的深厚的算法基础和数学能力