Q-learning小游戏实例

看了莫烦python的一位寻宝游戏,为了更好的掌握Q-table的使用,自己写了二维地图寻宝游戏,map 10*10,玩家initial location随机产生,treasure位于地图中央,完整代码见github。可以直接运行(只要环境对)

import 相应的库

由于这里没有玩很high level的游戏,所以没有使用gym,tensorflow等库,只是简单的numpy,panda等

import numpy as np
import pandas as pd
import time
import random
import matplotlib.pyplot as plt

超参数

N_states = 9  # the length of two-dimensional world: N_states*N_states
POS = 40 # the position of treasure (4,4)
ACTION = ["left", "right", "top", "bottom"] #avaliable actions
alpha = 0.1 #learning rate
gamma = 0.9 #discounted factor
episodes = 20 #the maximal number of iterations
interval = 0.3 #time required for each step
greedy = 0.9 # the greedy factor

寻宝过程探险者只有上下左右四个方向,ACTION完全是自由发挥的。

建立Q表

def build_table(nStates, actions):
    '''
    DataFrame是一个表格型的数据结构,由行和列组成,分别有行索引和列索引,且每列可以是不同类型的值。
    创建DataFrame对象
    创建的时候,可以通过参数index和columns分别指定行索引和列索引
    
    1.传入一个numpy的多维数组对象
        
    2.传入一个字典内部包含列表,字典内的列表是等长的, 字典的key默认为列索引
    '''
    table = pd.DataFrame(
            np.zeros((nStates*nStates, len(actions))),
           columns = actions)
    # q_table:
    """
        left  right  top  bottom
    0    0.0    0.0  0.0     0.0
    1    0.0    0.0  0.0     0.0
    2    0.0    0.0  0.0     0.0
    3    0.0    0.0  0.0     0.0
    4    0.0    0.0  0.0     0.0
    5    0.0    0.0  0.0     0.0
    6    0.0    0.0  0.0     0.0
    7    0.0    0.0  0.0     0.0
    8    0.0    0.0  0.0     0.0
    9    0.0    0.0  0.0     0.0
    """
    return table

根据状态和Q-table选择下一步的action

有随机的概率不使用argmax来选择行为,给sample增加随机性

def chooseActions(state, table):
    '''
    loc通过行和列 的名字获取值
    iloc通过下标获取值
    '''
    #print(state)
    actions = table.iloc[state,:] #choose corresponding values of actions about this state
    if(np.random.uniform() >= greedy) or (actions.all() == 0):#not greedy or all the actions have not been sampled
        actName = np.random.choice(ACTION);
    else:
        actName = actions.argmax()
    return actName

与环境的交互

actor每采取一个行动,我们更新他在地图上相应的位置,给他相应的reward,这里的reward是我调整过的,actor存活时间越久,距离中心位置越远,得到的惩罚越多,只有到达终点才能获得positive reward。当然这样的reward不一定是最优的,不同的游戏reward各种各样,需要慢慢尝试

#Interaction with environment and get feedback
'''
S: current state
A: name of action
'''
def feedback(S, A):
    # This is how agent will interact with the environment
    if A == "right": #move right
        if S == POS - 1:#POS is the destination, the center of the two-dimensional world
            S_ = "terminal" #only when we approach the treasure can we get the reward, 24
            R = 10
        elif S % N_states == N_states - 1: #on the right border
            S_ = S; R = -abs(S%N_states-POS%N_states)-abs(int(S/N_states)-int(POS/N_states))
        else:
            S_ = S + 1; R = -abs(S%N_states-POS%N_states)-abs(int(S/N_states)-int(POS/N_states))
    elif A == "left":
        if S == POS + 1:#12 is the destination, the center of the two-dimensional world
            S_ = "terminal" #only when we approach the treasure can we get the reward, 24
            R = 10
        elif S % N_states == 0: #on the left border
            S_ = S; R = -abs(S%N_states-POS%N_states)-abs(int(S/N_states)-int(POS/N_states))
        else:
            S_ = S - 1; R = -abs(S%N_states-POS%N_states)-abs(int(S/N_states)-int(POS/N_states))
    elif A == "top":
        if S == POS - N_states:#12 is the destination, the center of the two-dimensional world
            S_ = "terminal" #only when we approach the treasure can we get the reward, 24
            R = 10
        elif int(S/N_states) == N_states - 1: # on the top border
            S_ = S; R = -abs(S%N_states-POS%N_states)-abs(int(S/N_states)-int(POS/N_states))
        else:
            S_ = S + N_states; R = -abs(S%N_states-POS%N_states)-abs(int(S/N_states)-int(POS/N_states))
    elif A == "bottom":
        if S == POS + N_states:#12 is the destination, the center of the two-dimensional world
            S_ = "terminal" #only when we approach the treasure can we get the reward, 24
            R = 10
        elif int(S/N_states) == 0:#on the bottom
            S_ = S; R = -abs(S%N_states-POS%N_states)-abs(int(S/N_states)-int(POS/N_states))
        else:
            S_ = S - N_states; R = -abs(S%N_states-POS%N_states)-abs(int(S/N_states)-int(POS/N_states))
    return S_, R

更新环境

打印地图信息,actor所在位置,已经经过了多少次训练等必要的信息。

'''
S:state int: the position of explorer
episode: episode
stepCnt: total step in this episode
'''
def updateEnvironment(S, episode, stepCnt):
    #this is how the environment be updated
    envList =[ ['-' for i in range(N_states)] for j in range(N_states) ]
    envList[int(N_states/2)][int(N_states/2)] = 'T'
    #envList = ['-']*(N_states) + ['T']   # '---------T' our environment
    if S == "terminal":
         interaction = 'Episode %s: total_steps = %s' % (episode+1, stepCnt)
         print('\r{}'.format(interaction), end='')
         time.sleep(2)
         print('\r                                \n', end='')
    else:
         envList[int(S/N_states)][S%N_states] = 'o'
         #interaction = ''.join(envList)
         for i in range(N_states):
             print(envList[i])         
         print("\n")
         time.sleep(interval)

Q-table的更新

完全依照伪代码进行的参数更新

#most important part, the update of Q-table
def rl():
    x = []
    cnts = []
    QTable = build_table(N_states, ACTION);
    for episode in range(episodes):
        x.append(episode)
        stepCnt = 0
        S = random.randint(0,N_states*N_states-1) #the initial position
        terminated = False #whether terminated
        updateEnvironment(S, episode, stepCnt)
        while not terminated:
            
            a = chooseActions(S, QTable);#determine the next action according to the state and Q-table
            S_, R = feedback(S, a) #play the game, getting reward and the next state
            QPredict = QTable.loc[S, a] # the predict value of state-action pair
            if S_ != "terminal":
                QTarget = R + gamma * QTable.iloc[S_,:].max() #actual value of state-action pair
            else:
                QTarget = R #terminated
                terminated = True    # terminate this episode
            
            #updating Q-table, execting tempory difference method
            QTable.loc[S,a] += alpha * (QTarget - QPredict) 
            S = S_
            
            stepCnt += 1
            updateEnvironment(S, episode, stepCnt)
        cnts.append(stepCnt)
    plt.plot(x, cnts, color = "grey", linestyle = '--', marker = '^')
    plt.xlabel("episode")
    plt.ylabel("Total steps")
    return QTable

主函数

#begin to play
if __name__ == "__main__":
    q_table = rl()
    print('\r\nQ-table:\n')
    print(q_table)
  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值