模拟1000次后,在根节点下面选出一个最好的子节点
(用 CLion 运行 C++ 时,要先把项目中其他无关的源文件注释掉,否则运行 main 程序会报错,大概是因为引用到了其他文件中同名的函数。)
C++实现:
//
// on 2017/11/19.
//
#include <assert.h>

#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <memory>
#include <stack>
using namespace std;

// Small offset added to visit counts wherever they are used as a divisor,
// so that a node with zero visits never causes a division by zero.
const double EPSILON = 1e-6;

/**
 * Node of a toy UCT search tree with a fixed number of actions per node.
 *
 * Each node stores how often it was visited and the sum of the roll-out
 * results that passed through it. A node starts as a leaf; the first time a
 * search reaches it, it is expanded into `childNum` children.
 *
 * Children are owned through std::unique_ptr, so a subtree is released when
 * its parent is destroyed. Copying a node deep-copies its whole subtree.
 */
class UCTreeNode {
private:
    // Number of children created by expand(); also the size of the child array.
    static const int childNum = 5;

    std::unique_ptr<UCTreeNode> vpChildren_i[childNum];  // all null while the node is a leaf
    bool isLeaf_i = true;
    double nVisits_i = 0;    // number of times this node was on a search path
    double totValue_i = 0;   // sum of roll-out values propagated through this node

    /**
     * Picks the child with the largest UCT score
     *   totValue / (visits + EPSILON) + sqrt(log(parentVisits + 1) / (visits + EPSILON)).
     * Ties go to the child with the highest index because the comparison is `>=`.
     *
     * Precondition: the node has been expanded.
     * @return index of the selected child in [0, childNum).
     */
    int selectAction() const {
        assert(!isLeaf_i);
        const double logParentVisits = std::log(nVisits_i + 1);
        int selected = 0;
        double bestValue = -std::numeric_limits<double>::max();
        for (int k = 0; k < childNum; ++k) {
            const UCTreeNode *pCur = vpChildren_i[k].get();
            assert(pCur != nullptr);
            const double visits = pCur->nVisits_i + EPSILON;
            const double uctValue = pCur->totValue_i / visits + std::sqrt(logParentVisits / visits);
            if (uctValue >= bestValue) {
                selected = k;
                bestValue = uctValue;
            }
        }
        return selected;
    }

    /** Turns a leaf into an inner node by creating all children; no-op otherwise. */
    void expand() {
        if (!isLeaf_i) {
            return;
        }
        for (int k = 0; k < childNum; ++k) {
            vpChildren_i[k].reset(new UCTreeNode());
        }
        // Flip the flag only once every child exists.
        isLeaf_i = false;
    }

    /** Simulated playout: returns 0 or 1 taken from std::rand(). */
    int rollOut() {
        return std::rand() % 2;
    }

    /** Records one more visit and adds `value` to the accumulated total. */
    void updateStats(int value) {
        nVisits_i++;
        totValue_i += value;
    }

public:
    /** Creates an unvisited leaf node. */
    UCTreeNode() = default;

    /**
     * Deep copy: duplicates the statistics and, for expanded nodes, every
     * descendant of `tree`.
     */
    UCTreeNode(const UCTreeNode &tree)
        : isLeaf_i(tree.isLeaf_i), nVisits_i(tree.nVisits_i), totValue_i(tree.totValue_i) {
        if (tree.isLeaf_i) {
            return;
        }
        for (int k = 0; k < childNum; ++k) {
            assert(tree.vpChildren_i[k] != nullptr);
            vpChildren_i[k].reset(new UCTreeNode(*tree.vpChildren_i[k]));
        }
    }

    /**
     * Deep-copy assignment. The source is cloned before anything in *this is
     * released, so assigning from one of this node's own descendants is safe.
     */
    UCTreeNode &operator=(const UCTreeNode &tree) {
        if (this != &tree) {
            UCTreeNode clone(tree);
            for (int k = 0; k < childNum; ++k) {
                vpChildren_i[k].swap(clone.vpChildren_i[k]);
            }
            isLeaf_i = clone.isLeaf_i;
            nVisits_i = clone.nVisits_i;
            totValue_i = clone.totValue_i;
            // `clone` now holds the previous children and frees them on scope exit.
        }
        return *this;
    }

    /** @return true while the node has not been expanded. */
    bool isLeaf() const {
        return isLeaf_i;
    }

    /**
     * Runs one search iteration starting at this node:
     *   1. descend via selectAction() until a leaf is reached,
     *   2. expand that leaf and step into one of its new children,
     *   3. obtain a roll-out value,
     *   4. add that value to every node on the path (deepest node first),
     *      printing each node's statistics after it is updated.
     */
    void iterate() {
        std::stack<UCTreeNode *> visited;
        UCTreeNode *pCur = this;
        visited.push(pCur);
        while (!pCur->isLeaf()) {
            pCur = pCur->vpChildren_i[pCur->selectAction()].get();
            visited.push(pCur);
        }
        pCur->expand();
        pCur = pCur->vpChildren_i[pCur->selectAction()].get();
        visited.push(pCur);

        const int value = rollOut();
        while (!visited.empty()) {
            UCTreeNode *node = visited.top();
            visited.pop();
            node->updateStats(value);
            node->Value();
        }
    }

    /**
     * Returns the index of the child with the highest average value
     * (totValue / visits). A random perturbation of at most EPSILON is added
     * to each score before comparing.
     *
     * Precondition: the node has been expanded (iterate() was called at least once).
     * @return index of the best child in [0, childNum).
     */
    int bestAction() const {
        int selected = 0;
        double bestValue = -std::numeric_limits<double>::max();
        for (int k = 0; k < childNum; ++k) {
            const UCTreeNode *pCur = vpChildren_i[k].get();
            assert(pCur != nullptr);
            double expValue = pCur->totValue_i / (pCur->nVisits_i + EPSILON);
            expValue += static_cast<double>(std::rand()) * EPSILON / RAND_MAX;
            if (expValue >= bestValue) {
                selected = k;
                bestValue = expValue;
            }
        }
        return selected;
    }

    /** Prints "<total value>/<visit count>" followed by a newline to std::cout. */
    void Value() const {
        std::cout << totValue_i << "/" << nVisits_i << '\n';
    }
};
int main(){
UCTreeNode tree;
for(int k=0; k<1000; ++k)
{
tree.iterate();
cout << endl;
}
cout << endl;
int bestAction = tree.bestAction();
cout << "Best Action: " << bestAction << std::endl;
return 0;
}
Java实现:
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
/**
 * Node of a toy UCT search tree with a fixed number of actions per node.
 *
 * A node is a leaf until {@link #expand()} creates its children. Each node
 * counts its visits and accumulates the roll-out results that were
 * propagated through it.
 */
public class TreeNode {
    static Random r = new Random();
    static int nActions = 5;       // number of children created per expansion
    static double epsilon = 1e-6;  // keeps divisors non-zero and scales the tie-break noise
    TreeNode[] children;           // null until the node is expanded
    int nVisits, totValue;         // visit count, accumulated roll-out result

    /** @return true while this node has no children. */
    public boolean isLeaf() {
        return children == null;
    }

    /**
     * Returns the child with the largest UCT score
     * {@code totValue/(nVisits+epsilon) + sqrt(ln(parentVisits+1)/(nVisits+epsilon))},
     * plus a random term below {@code epsilon} that breaks ties.
     *
     * Must only be called on an expanded node; on a leaf {@code children} is
     * null and the loop throws a NullPointerException.
     */
    public TreeNode select() {
        TreeNode selected = null;
        // Start below every finite score so the first child always wins the comparison.
        double bestValue = Double.NEGATIVE_INFINITY;
        for (TreeNode c : children) {
            double uctValue = c.totValue / (c.nVisits + epsilon)
                    + Math.sqrt(Math.log(nVisits + 1) / (c.nVisits + epsilon))
                    + r.nextDouble() * epsilon;
            if (uctValue > bestValue) {
                selected = c;
                bestValue = uctValue;
            }
        }
        return selected;
    }

    /** Creates {@code nActions} fresh children for this node. */
    public void expand() {
        children = new TreeNode[nActions];
        for (int i = 0; i < nActions; i++) {
            // Array slots start out null, so every child must be constructed explicitly.
            children[i] = new TreeNode();
        }
    }

    /**
     * Runs one search iteration from this node: descend with {@link #select()}
     * until a leaf is reached, expand it, step into one new child, obtain a
     * roll-out value and add it to every node on the visited path. The path is
     * traced on standard output.
     */
    public void selectAction() {
        List<TreeNode> visited = new LinkedList<>();  // nodes on the current search path
        TreeNode cur = this;
        System.out.print("当前结点为:"+cur.totValue+"/"+cur.nVisits+" \n ");
        visited.add(this);
        while (!cur.isLeaf()) {
            // Walk down to the child with the best UCT score.
            cur = cur.select();
            visited.add(cur);
            System.out.print("下一级结点是"+cur.totValue+"/"+cur.nVisits+" ");
        }
        System.out.print("\n");
        cur.expand();
        TreeNode newNode = cur.select();
        visited.add(newNode);
        int value = rollOut();
        for (TreeNode node : visited) {
            // A game with more than one player would need different update logic here.
            node.updateState(value);
        }
    }

    /** Simulated playout: returns 0 or 1 at random. */
    public int rollOut() {
        return r.nextInt(2);
    }

    /** Records one more visit and adds {@code value} (truncated to int) to the total. */
    public void updateState(double value) {
        nVisits++;
        totValue += value;
    }

    /** @return the number of children, or 0 for a leaf. */
    public int arity() {
        return children == null ? 0 : children.length;
    }
}
/** Entry point: runs 1000 search iterations and prints the selected root child. */
class m {
    public static void main(String[] args) {
        // A new TreeNode already starts with totValue == 0 and nVisits == 0.
        TreeNode tree = new TreeNode();
        for (int n = 0; n < 1000; n++) {
            tree.selectAction();
        }
        // select() adds random tie-break noise, so call it once and print both
        // numbers from the same child.
        TreeNode best = tree.select();
        System.out.println(best.totValue + "/" + best.nVisits);
    }
}