读源码学算法之Monte Carlo Tree Search

最近研究新的算法有使用到Monte Carlo Tree Search,查了一些资料,参考几篇博客:
1.知乎:蒙特卡洛树搜索最通俗入门指南
2.知乎:AlphaGo背后的力量:蒙特卡洛树搜索入门指南
3.https://blog.csdn.net/caozixuan98724/article/details/103213795
4.https://blog.csdn.net/kurumi233/article/details/80307788
5.https://vgarciasc.github.io/mcts-viz/
在这里插入图片描述

我认真研究了一份代码,用蒙特卡洛实现井字棋人机对战的代码,非常感谢作者!
代码来自:https://github.com/cryer/monte-carlo-tree-search

实在没空写博客了,只能贴代码,我写了详细的注释,希望以后复习方便!

一、主代码 main.py
import numpy as np
from node import *
from search import MonteCarloTreeSearch
from tictactoe import TicTacToeGameState

def init():
    """Let the machine make the opening move.

    Starting from an empty 3x3 board (all zeros), run 1000 MCTS simulations
    and return the game state after the machine's best first placement.
    """
    empty_board = np.zeros((3, 3))
    start_state = TicTacToeGameState(state=empty_board, next_to_move=1)
    search_root = MonteCarloTreeSearchNode(state=start_state, parent=None)
    searcher = MonteCarloTreeSearch(search_root)
    return searcher.best_action(1000).state

#绘制棋盘格
def draw_chessboard(c_board):
    for i in range(3):
        print("\n{0:3}".format(i).center(8)+"|", end='')
        for j in range(3):
            if c_board[i][j] ==  0: print('_'.center(8), end='')
            if c_board[i][j] ==  1: print('X'.center(8), end='')
            if c_board[i][j] == -1: print('O'.center(8), end='')
    print("\n______________________________")

# Human's turn: place a -1 mark on a free cell of the 3x3 board.
def get_human_action(state):
    """Prompt the human for a move until a legal one is entered.

    Input format: "x,y" (0-based coordinates). The machine plays value +1,
    the human plays -1. Malformed input (non-integers, fewer than two
    coordinates) and illegal placements both re-prompt instead of crashing;
    the original recursive retry also raised ValueError/IndexError on bad
    input. Returns the validated TicTacToeMove.
    """
    while True:
        raw = input("Your move: ")
        try:
            coords = [int(token, 10) for token in raw.split(",")]
            x, y = coords[0], coords[1]
        except (ValueError, IndexError):
            # Non-numeric tokens or fewer than two coordinates: retry.
            print("invalid move")
            continue
        move = TicTacToeMove(x, y, -1)  # human always plays -1
        if state.is_move_legal(move):   # bounds, turn, and empty-cell checks
            return move
        print("invalid move")

# Report whether the game has reached a terminal state (win / loss / tie).
def is_game_over(state):
    """Return 1 and print the outcome if the game ended, else return 0.

    The result is from the machine's perspective: +1.0 machine wins
    ("You lose!"), 0.0 tie, -1.0 human wins ("You Win!").
    """
    if not state.is_game_over():
        return 0
    outcome = state.game_result
    if outcome == 1.0:
        print("You lose!")
    elif outcome == 0.0:
        print("Tie!")
    elif outcome == -1.0:
        print("You Win!")
    return 1

if __name__ == '__main__':
    # Machine opens the game with its best first move, then show the board.
    c_state = init()
    draw_chessboard(c_state.board)

    # Alternate human move -> machine move until a terminal state.
    while is_game_over(c_state) != 1:
        # 1. Human places a mark.
        human_move = get_human_action(c_state)
        c_state = c_state.move(human_move)
        draw_chessboard(c_state.board)

        # BUGFIX: if the human's move just ended the game, we must stop here.
        # The original code ran MCTS on the terminal board; with no legal
        # moves the root gets no children and best_child's np.argmax([])
        # raises ValueError. is_game_over also prints the final outcome.
        if is_game_over(c_state) == 1:
            break

        # 2. Machine searches for its best reply (1000 simulations).
        board_state = TicTacToeGameState(state=c_state.board, next_to_move=1)
        root = MonteCarloTreeSearchNode(state=board_state, parent=None)  # current position as new search root
        mcts = MonteCarloTreeSearch(root)
        best_computer_node = mcts.best_action(1000)
        c_state = best_computer_node.state
        draw_chessboard(c_state.board)

二、蒙特卡洛搜索代码 search.py
from node import MonteCarloTreeSearchNode

class MonteCarloTreeSearch:
    """Driver for Monte Carlo tree search starting from a given root node."""

    def __init__(self, node: MonteCarloTreeSearchNode):
        # Node representing the current game position; all searches start here.
        self.root = node

    def draw_chessboard(self, c_board):
        """Debug helper: print a 3x3 board ('_' empty, 'X' = +1, 'O' = -1)."""
        marks = {0: '_', 1: 'X', -1: 'O'}
        for row in range(3):
            print("\n{0:3}".format(row).center(8) + "|", end='')
            for col in range(3):
                mark = marks.get(c_board[row][col])
                if mark is not None:
                    print(mark.center(8), end='')
        print("\n______________________________")

    def best_action(self, simulations_number):
        """Run `simulations_number` MCTS iterations and return the best child of root.

        Each iteration: select/expand a leaf, play a random rollout from it,
        and backpropagate the outcome up to the root.
        """
        for _ in range(simulations_number):
            leaf = self.expand_node()                    # selection + expansion
            outcome = leaf.simulate_man_machine_game()   # random playout to win/loss/tie
            leaf.backpropagate(outcome)                  # update stats along the path to root
        # c_param=0: pure exploitation — pick the statistically best move.
        return self.root.best_child(c_param=0.)

    def expand_node(self):
        """Tree policy: walk down via best_child until a node with untried moves (expand it) or a terminal node."""
        node = self.root
        while not node.is_terminal_node():
            if node.is_fully_expanded():
                node = node.best_child()    # all children tried: descend to the most promising one
            else:
                return node.expand()        # attach and return one new child
        return node                         # terminal node: nothing to expand

三、蒙特卡洛树节点类 node.py
import numpy as np
from collections import defaultdict
from tictactoe import *

class MonteCarloTreeSearchNode(object):
    """A node in the Monte Carlo search tree.

    Wraps one game state and accumulates rollout statistics (visit count and
    per-result tallies) used by the UCT formula in best_child.
    """

    def __init__(self, state: TicTacToeGameState, parent=None):
        self._number_of_visits = 0.       # times this node appeared on a backpropagation path
        self._results = defaultdict(int)  # rollout result (-1./0./1.) -> count
        self.state = state                # game position this node represents
        self.parent = parent              # parent node; None for the search root
        self.children = []                # children created by expand()

    @property
    def untried_actions(self): #当前状态下,下一步可以走的所有方案存入_untried_actions
        # Lazily cache the legal moves not yet expanded from this state;
        # expand() pops from this list until it is empty.
        if not hasattr(self, '_untried_actions'):
            self._untried_actions = self.state.get_legal_actions()
        return self._untried_actions

    @property
    def q(self):
        # Win-minus-loss balance from the perspective of the player to move at
        # the PARENT node (i.e. the player whose move created this node).
        # NOTE(review): reading q on the root (parent is None) would raise
        # AttributeError; in practice it is only read on children via best_child.
        wins = self._results[self.parent.state.next_to_move]  # rollouts won by the parent's mover
        loses = self._results[-1 * self.parent.state.next_to_move] # rollouts lost by the parent's mover
        return wins - loses

    @property
    def n(self):
        # Visit count (kept as float for the division in best_child).
        return self._number_of_visits

    def expand(self):
        """Create, attach, and return a child node for one untried move."""
        action = self.untried_actions.pop()     # consume one untried move
        next_state = self.state.move(action)    # apply it; move() also flips next_to_move

        child_node = MonteCarloTreeSearchNode(next_state, parent=self) #next_state->child_node
        self.children.append(child_node) # keep the child for future selection
        return child_node

    def is_terminal_node(self): # win, loss, and tie all count as terminal
        return self.state.is_game_over()

    def simulate_man_machine_game(self): # random rollout from this state to a terminal result
        """Play random moves for both sides until the game ends; return game_result."""
        current_rollout_state = self.state
        while not current_rollout_state.is_game_over():
            possible_moves = current_rollout_state.get_legal_actions()
            action = self.rollout_policy(possible_moves)
            current_rollout_state = current_rollout_state.move(action)
        return current_rollout_state.game_result

    def backpropagate(self, result): # recursively push the rollout result up to the root
        """Record one rollout outcome on this node and every ancestor."""
        self._number_of_visits += 1.
        self._results[result] += 1.
        if self.parent:
            self.parent.backpropagate(result)

    def is_fully_expanded(self):
        # True once every legal move from this state has produced a child.
        return len(self.untried_actions) == 0

    def best_child(self, c_param=1.4): # pick the child maximizing the UCT score
        """Return the child maximizing UCB1: q/n + c * sqrt(2 ln N / n).

        c_param balances exploration vs. exploitation; c_param=0 (used for the
        final move choice) ranks children purely by their win statistics.
        """
        choices_weights = [
            (c.q / (c.n)) + c_param * np.sqrt((2 * np.log(self.n) / (c.n)))
            for c in self.children
        ]
        return self.children[np.argmax(choices_weights)]

    def rollout_policy(self, possible_moves): # uniformly random move for simulations
        return possible_moves[np.random.randint(len(possible_moves))]

四、井字棋状态类 tictactoe.py
import numpy as np

class TicTacToeMove(object):
    """One placement on the board: a coordinate pair plus the mark to put there."""

    def __init__(self, x_coordinate, y_coordinate, value):
        # value encodes the mark: +1 = X (machine), -1 = O (human), 0 = empty cell
        self.x_coordinate = x_coordinate  # board row index
        self.y_coordinate = y_coordinate  # board column index
        self.value = value

class TicTacToeGameState(object):
    """Immutable tic-tac-toe position: an NxN numpy board plus whose turn it is.

    Board cells hold +1 (machine, X), -1 (human, O), or 0 (empty). move()
    returns a new state instead of mutating this one.
    """

    x = +1  # machine's mark (X)
    o = -1  # human's mark (O)

    def __init__(self, state, next_to_move=1):
        self.board = state                  # numpy array of shape (size, size)
        self.board_size = state.shape[0]
        self.next_to_move = next_to_move    # +1: machine moves next, -1: human

    @property
    def game_result(self):
        """Result from the machine's perspective: 1. win, -1. loss, 0. tie, None ongoing."""
        # np.sum(axis=0) collapses rows, yielding one total per COLUMN, and
        # axis=1 yields one total per ROW (the original code had these labels
        # swapped; win detection is unaffected since both sets are checked).
        col_sums = np.sum(self.board, 0)
        row_sums = np.sum(self.board, 1)
        main_diag = self.board.trace()        # top-left to bottom-right
        anti_diag = self.board[::-1].trace()  # bottom-left to top-right
        target = self.board_size

        # Any full line of +1s: machine wins.
        if any(col_sums == target) or any(row_sums == target) \
                or main_diag == target or anti_diag == target:
            return 1.
        # Any full line of -1s: machine loses.
        if any(col_sums == -target) or any(row_sums == -target) \
                or main_diag == -target or anti_diag == -target:
            return -1.
        # Board full with no winner: tie.
        if np.all(self.board != 0):
            return 0.
        # Otherwise the game is still in progress.
        return None

    def is_game_over(self):
        """True once game_result is decided (win, loss, or tie)."""
        # `is not None` replaces the original `!= None` (equality comparison
        # with None is a Python anti-idiom).
        return self.game_result is not None

    def is_move_legal(self, move):
        """A move is legal iff it is the mover's turn, in bounds, and the cell is empty."""
        # The mark being placed must match whose turn it is.
        if move.value != self.next_to_move:
            return False
        # Both coordinates must lie on the board.
        if not (0 <= move.x_coordinate < self.board_size):
            return False
        if not (0 <= move.y_coordinate < self.board_size):
            return False
        # The target cell must be empty.
        return self.board[move.x_coordinate, move.y_coordinate] == 0

    def move(self, move):
        """Return the successor state after placing move.value; the turn is flipped."""
        new_board = np.copy(self.board)
        new_board[move.x_coordinate, move.y_coordinate] = move.value
        next_player = TicTacToeGameState.o if self.next_to_move == TicTacToeGameState.x else TicTacToeGameState.x
        return TicTacToeGameState(new_board, next_player)

    def get_legal_actions(self):
        """Every empty cell as a TicTacToeMove for the player to move."""
        rows, cols = np.where(self.board == 0)
        return [TicTacToeMove(r, c, self.next_to_move) for r, c in zip(rows, cols)]

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Researcher-Du

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值