用 C++ 和 Java 实现的一棵纯 MCTS(蒙特卡洛树搜索)树

模拟1000次后,在根节点下面选出一个最好的子节点
(用 CLion 运行 C++ 代码时,需要把项目中其他无关的源文件注释掉,否则运行 main 程序会报错——很可能是因为与其他文件中的同名函数发生了冲突。)


C++实现:

//
//  on 2017/11/19.
//
#include <assert.h>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <stack>
using namespace std;
// Small positive offset added to visit counts so that the value/visit
// divisions below stay finite for nodes that have never been visited.
const double EPSILON = 1e-6;

/**
 * Node of a plain UCT (Monte-Carlo tree search) tree with a fixed branching
 * factor of five.
 *
 * A node starts as a leaf with no children. expand() allocates all five
 * children at once; the node owns them and releases them in its destructor.
 * Copy construction and copy assignment perform a deep copy of the subtree.
 */
class UCTreeNode {
private:
    UCTreeNode *vpChildren_i[5]; // owned children; all null while the node is a leaf
    bool isLeaf_i = true;        // true until expand() has created the children
    double nVisits_i = 0;        // number of times this node was on a visited path
    double totValue_i = 0;       // sum of rollout results backed up through this node
    int childNum = 5;            // number of children per expanded node

    /**
     * Picks the child with the highest UCT score:
     * mean value + sqrt(log(parent visits + 1) / child visits).
     * Ties go to the child with the larger index (comparison uses >=).
     *
     * @return index of the selected child in vpChildren_i
     */
    int selectAction() {
        assert(!isLeaf_i); // only expanded nodes have children to choose from
        int selected = 0;
        double bestValue = -std::numeric_limits<double>::max();
        for (int k = 0; k < childNum; ++k) {
            UCTreeNode *pCur = vpChildren_i[k];
            assert(0 != pCur);
            double uctValue = pCur->totValue_i / (pCur->nVisits_i + EPSILON) +
                              std::sqrt(std::log(nVisits_i + 1) / (pCur->nVisits_i + EPSILON));
            if (uctValue >= bestValue) {
                selected = k;
                bestValue = uctValue;
            }
        }
        return selected;
    } // selectAction

    /** Creates the five children of a leaf; does nothing on an expanded node. */
    void expand() {
        if (!isLeaf_i)
            return;
        for (int k = 0; k < childNum; ++k)
            vpChildren_i[k] = new UCTreeNode();
        // Flip the flag only after every child exists.
        isLeaf_i = false;
    } // expand

    /** Simulated playout: returns 0 or 1 taken from rand(). */
    int rollOut() {
        return std::rand() % 2;
    } // rollout

    /** Records one more visit and adds the rollout result to the running total. */
    void updateStats(int value) {
        nVisits_i++;
        totValue_i += value;
    }

    /** Deep-copies every child of `other` into this node's child slots. */
    void cloneChildrenFrom(const UCTreeNode &other) {
        for (int k = 0; k < childNum; ++k) {
            const UCTreeNode *src = other.vpChildren_i[k];
            vpChildren_i[k] = (0 != src) ? new UCTreeNode(*src) : 0;
        }
    }

public:
    /** Creates an unvisited leaf node. */
    UCTreeNode() {
        for (int k = 0; k < childNum; ++k) {
            vpChildren_i[k] = 0;
        }
    }  // default constructor

    /**
     * Deep copy: statistics, the leaf flag and the whole subtree are taken
     * from `tree` (the source object, not the half-built copy).
     */
    UCTreeNode(const UCTreeNode &tree)
        : isLeaf_i(tree.isLeaf_i),
          nVisits_i(tree.nVisits_i),
          totValue_i(tree.totValue_i),
          childNum(tree.childNum) {
        cloneChildrenFrom(tree);
    } // copy constructor

    /**
     * Deep-copy assignment. The source is copied into a temporary first, so
     * assigning from one of this node's own descendants is safe; the
     * temporary's destructor then releases the previous children.
     */
    UCTreeNode &operator=(const UCTreeNode &tree) {
        if (this == &tree)
            return *this;
        UCTreeNode snapshot(tree);
        for (int k = 0; k < childNum; ++k) {
            UCTreeNode *previous = vpChildren_i[k];
            vpChildren_i[k] = snapshot.vpChildren_i[k];
            snapshot.vpChildren_i[k] = previous;
        }
        isLeaf_i = snapshot.isLeaf_i;
        nVisits_i = snapshot.nVisits_i;
        totValue_i = snapshot.totValue_i;
        return *this;
    } // copy assignment

    /** Releases the owned subtree (deleting a null child is a no-op). */
    ~UCTreeNode() {
        for (int k = 0; k < childNum; ++k) {
            delete vpChildren_i[k];
            vpChildren_i[k] = 0;
        }
    } // destructor

    /** @return true while the node has not been expanded. */
    bool isLeaf() const {
        return isLeaf_i;
    }

    /**
     * Runs one search iteration starting at this node: descend by UCT to a
     * leaf, expand it, step into one of the new children, perform a rollout
     * and back the result up along the visited path. The statistics of every
     * node on the path are printed while backing up.
     */
    void iterate() {
        std::stack<UCTreeNode *> visited;
        UCTreeNode *pCur = this;
        visited.push(this);
        int action = 0; // next selected action

        // Selection: follow the best UCT child until a leaf is reached.
        while (!pCur->isLeaf()) {
            action = pCur->selectAction();
            pCur = pCur->vpChildren_i[action];
            visited.push(pCur);
        }

        // Expansion: grow the leaf and move into one of its new children.
        pCur->expand();
        action = pCur->selectAction();
        pCur = pCur->vpChildren_i[action];
        visited.push(pCur);

        // Simulation and backpropagation.
        int value = rollOut();
        while (!visited.empty()) {
            pCur = visited.top();
            pCur->updateStats(value);
            visited.pop();
            pCur->Value();
        }
    } // iterate

    /**
     * Chooses the child with the highest mean value (total value / visits),
     * with a random perturbation below EPSILON to break ties. Unlike
     * selectAction() there is no exploration term. Must be called on an
     * expanded node; a null child trips the assertion.
     *
     * @return index of the chosen child
     */
    int bestAction() {
        int selected = 0;
        double bestValue = -std::numeric_limits<double>::max();
        for (int k = 0; k < childNum; ++k) {
            UCTreeNode *pCur = vpChildren_i[k];
            assert(0 != pCur);
            double expValue = pCur->totValue_i / (pCur->nVisits_i + EPSILON);
            expValue += static_cast<double>(std::rand()) * EPSILON / RAND_MAX;
            if (expValue >= bestValue) {
                selected = k;
                bestValue = expValue;
            }
        }
        return selected;
    } // bestAction

    /** Prints "<total value>/<visits>" followed by a newline. */
    void Value() const {
        std::cout << totValue_i << "/" << nVisits_i << std::endl;
    }
};

int main(){
    UCTreeNode tree;

    for(int k=0; k<1000; ++k)
    {
        tree.iterate();
        cout << endl;
    }
    cout << endl;
    int bestAction = tree.bestAction();
    cout << "Best Action: " << bestAction << std::endl;
    return 0;
}

java实现:


import java.util.LinkedList;
import java.util.List;
import java.util.Random;

/**
 * Node of a plain UCT (Monte-Carlo tree search) tree with a fixed number of
 * actions per node.
 *
 * <p>A node is a leaf until {@link #expand()} creates its children. Each call
 * to {@link #selectAction()} runs one complete search iteration from the node
 * it is invoked on.
 */
public class TreeNode {
    /** Shared random source for rollouts and tie-breaking. */
    static Random r = new Random();
    /** Number of actions, i.e. children created per expanded node. */
    static int nActions=5;
    /** Small offset that keeps divisions finite for unvisited nodes. */
    static double epsilon =1e-6;

    /** Children of this node; {@code null} while the node is a leaf. */
    TreeNode[] children;
    /** Visit count and accumulated rollout value of this node. */
    int nVisits,totValue;

    /**
     * Tells whether this node has been expanded yet.
     *
     * @return {@code true} if the node has no children
     */
    public boolean isLeaf(){
        return children==null;
    }

    /**
     * Returns the child with the highest UCT score: mean value plus
     * {@code sqrt(log(parent visits + 1) / child visits)}, plus a random
     * perturbation below {@code epsilon} to break ties.
     *
     * @return the selected child, never {@code null}
     * @throws IllegalStateException if this node is a leaf
     */
    public TreeNode select(){
        if (isLeaf()) {
            throw new IllegalStateException("select() called on a leaf node; expand() it first");
        }
        TreeNode selected=null;
        // Start below every possible score so the first child always wins the
        // first comparison; Double.MIN_VALUE would not do, it is positive.
        double bestValue = Double.NEGATIVE_INFINITY;
        for (TreeNode c:children){
            double uctValue =c.totValue/ (c.nVisits+epsilon)+
                    Math.sqrt(Math.log(nVisits+1)/(c.nVisits+epsilon))+r.nextDouble()*epsilon;
            if(uctValue>bestValue){
                selected=c;
                bestValue=uctValue;
            }
        }
        return selected;
    }

    /** Creates {@code nActions} fresh children, replacing any existing ones. */
    public void expand(){
        children=new TreeNode[nActions];
        for(int i=0;i<nActions;i++){
            // Array slots of a reference type start out null and must be filled.
            children[i]=new TreeNode();
        }
    }

    /**
     * Runs one search iteration from this node: descend by UCT to a leaf,
     * expand it, step into one of the new children, roll out, and back the
     * result up along the visited path. The path is printed while descending.
     */
    public void selectAction(){
        List<TreeNode> visited =new LinkedList<>(); // nodes on the path of this iteration
        TreeNode cur=this;
        System.out.print("当前结点为:"+cur.totValue+"/"+cur.nVisits+" \n ");
        visited.add(this);

        // Selection: follow the best UCT child until a leaf is reached.
        while(!cur.isLeaf()){
            cur=cur.select();
            visited.add(cur);
            System.out.print("下一级结点是"+cur.totValue+"/"+cur.nVisits+"  ");
        }
        System.out.print("\n");

        // Expansion: create all children of the leaf and move into one of them.
        cur.expand();
        TreeNode  newNode = cur.select();
        visited.add(newNode);

        // Simulation and backpropagation. A game with more than one kind of
        // player would need different update logic per node.
        int value=rollOut();
        for (TreeNode node :visited){
            node.updateState(value);
        }
    }

    /**
     * Simulated playout; a real implementation would play the game out.
     *
     * @return 0 or 1, chosen at random
     */
    public int rollOut(){
        return r.nextInt(2);
    }

    /**
     * Records one more visit and adds {@code value} to the accumulated total.
     *
     * @param value rollout result to back up; the compound assignment narrows
     *              it to the {@code int} total
     */
    public void updateState(double value){
        nVisits++;
        totValue+=value;
    }

    /**
     * Reports how many children this node has.
     *
     * @return the number of children, 0 for a leaf
     */
    public int arity(){
        return children==null?0:children.length;
    }
}

/** Demo driver: runs 1000 search iterations and prints the preferred root child. */
class m{
    /**
     * Builds a fresh tree, runs 1000 iterations from the root and prints the
     * statistics of the selected root child as "totValue/nVisits".
     *
     * @param args unused
     */
    public static void main(String[] args) {

        TreeNode tree=new TreeNode();
        tree.totValue=0;
        tree.nVisits=0;
        int n=0;
        while(n++<1000) {
            tree.selectAction();
        }
        // Select once: select() adds a random tie-break, so two separate calls
        // could return different children and print mismatched numbers.
        TreeNode best=tree.select();
        System.out.println(best.totValue+"/"+best.nVisits);
    }
}
  • 2
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值