Muzero算法是什么?
Muzero建立在alphaZero算法的搜索能力以及基于搜索的策略迭代算法之上,同时联合了一个学习model在训练的过程中,极大的扩展了该学习算法的应用场景。
它主要是将mentor Calor Tree Search算法、hidden state value equivalence思想,以及Deep Neural Network 相结合,创造一个更加general的算法来进行强化学习的训练,一句话总结就是 Muzero算法是一个不需要environment dynamic先验,更加general的model-based的强化学习算法。
Muzero算法的算法过程是什么样的?
1、hidden state value equivalence思想
强化学习一直以来存在的一个大问题就是如何处理training space与 application space的状态的offset问题,之前就有很多人提出利用transfer learning 或者 Domain Adaptation来解决两个space之间的offset,但是hidden state value equivalence的思想是从state本身出发,在训练整个网络的时候并不是直接用训练场景的原始state作为输入,而是把原始的state通过一个神经网络进行转换,变换成hidden state。举个例子,人在实验室里面或者一个工厂里面走路(planning)的时候能够保证自己不碰撞到任何障碍物,而实验室的场景和工厂大不相同,也就是说真正决定人不会碰撞到障碍物不是一个具体场景,更像场景中一个抽象的结构信息,这里的hidden state就可以类比成这个抽象的结构信息。这样做的好处显而易见,就是能够提高网络的泛化能力以及网络的收敛性。只要hidden state对应的value和真实场景中对应的value是一致的,那么我们就能说hidden state可以在该任务中替代 真实 state。
2、Muzero算法过程
1、Muzero model包含三个部分(对应于上图的A):
(1)representation function:将场景中的原始state转换成hidden state
(2)dynamics function:给定一个hidden state S(k-1)和一个挑选好的 action a(k),dynamics 计算出一个即时reward r(k)和下一个hidden state S(k)
(3)prediction function:给定dynamics计算出来的S(k),prediction计算出对应的policy p(k)以及对应的value Q(k)
2、Muzero算法是如何采集训练数据的?(对应于上图的B)
每隔一个时间tick,采集到原始state 观测 o(t),此时会采用蒙特卡罗搜索树的搜索策略,也就是从root node到该node访问次数最多的这个node所产生的action a(t+1),在环境中执行这个action之后会获得一个observation o(t+1)以及一个即时的reward r(t+1),然后这个组合 [o(t), a(t+1), o(t+1), r(t+1)] 就会被存储在replay buffer中,直到该个game结束,我们把这一个个组合连在一起,称为一个trajectory。
3、Muzero算法是如何利用采集好的数据进行训练的?
如上图中的C,
(1)算法首先从采集的众多trajectory中随机sample 一条 ,然后取前k个时间的数据作为初始的state,利用presentation 网络首先计算一个初始的hidden state s(0);
(2)prediction 网络输入的hidden state 计算出对应的policy 和 value;
(3)dynamics 网络预测出对应的reward和下一个hidden state;
如此反复循环k steps进行训练,所以一组训练用到的数据个数是t+k个
4、Muzero具体的网络结构:
Muzero算法主要有三个网络,分别是:
(1)policy network:对应于下图loss function的第三项,主要是为了minimum policy network输出的policy与 MCTS输出的policy之间的error
(2)value network:对应于下图loss function的第二项,主要是为了minimum value network输出的value与指定的value之间的error
(3)reward network:对应于下图loss function的第一项,主要是minimum预测的reward与真实在采集数据获得的reward之间的error。
最后Muzero的loss function:
在训练的时候也用到了MCTS算法,不仅仅是在采集数据的时候用到了,policy network产生的π也是用到了MCTS算法。
注意:presentation网络只在初始的时候产生一个hidden state s(0),后面的hidden state全部是由dynamics网络预测出来的。
Muzero算法依赖的蒙特卡罗搜索树的原理是什么?
网上关于蒙特卡罗搜索树的原理讲解太多了,下面我贴一个知乎上写的很好的一篇文章:
https://zhuanlan.zhihu.com/p/53948964
参考文献:
[1] Schrittwieser, J., Antonoglou, I., Hubert, T., Simonyan, K., Sifre, L., Schmitt, S., … & Silver, D. (2020). Mastering atari, go, chess and shogi by planning with a learned model. Nature, 588(7839), 604-609.
[2] Silver, D., Hasselt, H., Hessel, M., Schaul, T., Guez, A., Harley, T., … & Degris, T. (2017, July). The predictron: End-to-end learning and planning. In International Conference on Machine Learning (pp. 3191-3199). PMLR.
[3] https://zhuanlan.zhihu.com/p/53948964