After working through 莫烦Python's one-dimensional treasure-hunt example, I wrote a two-dimensional treasure-hunt game myself to get a better grasp of Q-tables. The map is 9×9 (so the treasure can sit exactly at the center cell (4,4)) and the player's initial location is random. The complete code is on GitHub and can be run directly (given a suitable environment).
Importing the libraries
Since this is not a high-level game, there is no need for gym, tensorflow, or the like; plain numpy, pandas, and matplotlib are enough.
import numpy as np
import pandas as pd
import time
import random
import matplotlib.pyplot as plt
Hyperparameters
N_states = 9 # side length of the two-dimensional world: N_states*N_states cells
POS = 40 # the state index of the treasure, i.e. cell (4,4)
ACTION = ["left", "right", "top", "bottom"] # available actions
alpha = 0.1 # learning rate
gamma = 0.9 # discount factor
episodes = 20 # number of training episodes
interval = 0.3 # pause between rendered steps, in seconds
greedy = 0.9 # the epsilon-greedy factor: act greedily with this probability
During the hunt the explorer can only move in the four directions; the names in ACTION are entirely up to you. One more convention worth spelling out: the code flattens each (row, col) cell into a single integer state, as sketched below.
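A minimal sketch of that encoding (the helper names toState and toCell are mine for illustration, not part of the game code):

def toState(row, col):
    # flatten (row, col) into the single integer state used throughout the code
    return row * N_states + col

def toCell(S):
    # recover (row, col) from a state index
    return S // N_states, S % N_states

print(toState(4, 4))  # 40, i.e. POS, the treasure at the center
print(toCell(40))     # (4, 4)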
Building the Q-table
def build_table(nStates, actions):
    '''
    A DataFrame is a table-like structure made of rows and columns, each with
    its own index, and each column may hold a different type of value.
    When constructing one, the index and columns parameters set the row and
    column labels. You can pass in either:
    1. a numpy multi-dimensional array, or
    2. a dict of equal-length lists, whose keys become the column labels.
    '''
    table = pd.DataFrame(
        np.zeros((nStates*nStates, len(actions))),  # one row per cell, one column per action
        columns = actions)
    # q_table (first rows of the 81-row table):
    """
       left  right  top  bottom
    0   0.0    0.0  0.0     0.0
    1   0.0    0.0  0.0     0.0
    2   0.0    0.0  0.0     0.0
    3   0.0    0.0  0.0     0.0
    4   0.0    0.0  0.0     0.0
    ...
    """
    return table
Choosing the next action from the state and the Q-table
With some probability we skip argmax and pick an action at random, which adds exploration to the sampling.
def chooseActions(state, table):
    '''
    loc selects values by row and column label;
    iloc selects values by integer position.
    '''
    actions = table.iloc[state, :]  # Q-values of every action in this state
    # explore: act non-greedily, or pick randomly while this state is still all zeros
    if (np.random.uniform() >= greedy) or ((actions == 0).all()):
        actName = np.random.choice(ACTION)
    else:
        actName = actions.idxmax()  # idxmax returns the column label, i.e. the action name
    return actName
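A quick check of the two branches (a minimal sketch using the functions above; the exact random actions will vary):

table = build_table(N_states, ACTION)
# a fresh table is all zeros, so the agent always explores:
print([chooseActions(0, table) for _ in range(5)])  # e.g. ['top', 'left', 'left', 'right', 'bottom']
# once a state has a learned value, the greedy branch can fire:
table.loc[0, "right"] = 1.0
print(chooseActions(0, table))  # "right" with probability greedy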
Interacting with the environment
Each time the actor takes an action, we update its position on the map and hand back a reward. I tuned the reward myself: every step costs the actor its Manhattan distance from the treasure, so the longer it survives and the farther it strays from the center, the more penalty it accumulates; only reaching the goal yields a positive reward. This scheme is not necessarily optimal; different games call for all kinds of rewards, and it takes experimentation (see the sanity check after the function below).
#Interaction with the environment: take action A in state S and get feedback
'''
S: current state, an integer cell index
A: name of the action
'''
def feedback(S, A):
    # every non-terminal step is penalized by the Manhattan distance from S to the treasure
    R = -abs(S % N_states - POS % N_states) - abs(S // N_states - POS // N_states)
    if A == "right":
        if S == POS - 1:  # one step left of the treasure: done
            S_ = "terminal"
            R = 10  # only reaching the treasure earns a positive reward
        elif S % N_states == N_states - 1:  # on the right border: stay put
            S_ = S
        else:
            S_ = S + 1
    elif A == "left":
        if S == POS + 1:  # one step right of the treasure: done
            S_ = "terminal"
            R = 10
        elif S % N_states == 0:  # on the left border: stay put
            S_ = S
        else:
            S_ = S - 1
    elif A == "top":  # note: "top" increases the row index (rows are printed top to bottom)
        if S == POS - N_states:  # one row before the treasure: done
            S_ = "terminal"
            R = 10
        elif S // N_states == N_states - 1:  # on the last row: stay put
            S_ = S
        else:
            S_ = S + N_states
    elif A == "bottom":
        if S == POS + N_states:  # one row past the treasure: done
            S_ = "terminal"
            R = 10
        elif S // N_states == 0:  # on the first row: stay put
            S_ = S
        else:
            S_ = S - N_states
    return S_, R
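A quick sanity check of the reward scheme; with N_states = 9 and POS = 40, a corner cell is penalized -(4+4) = -8:

print(feedback(0, "right"))   # (1, -8): step right from the corner, Manhattan penalty
print(feedback(8, "right"))   # (8, -8): blocked by the right border, position unchanged
print(feedback(39, "right"))  # ('terminal', 10): stepping onto the treasure at (4,4)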
Updating the environment
Print the map, the actor's position, how many episodes have been trained, and other necessary bookkeeping.
'''
S: state (int), the position of the explorer, or "terminal"
episode: index of the current episode
stepCnt: total steps taken in this episode
'''
def updateEnvironment(S, episode, stepCnt):
    # this is how the environment gets rendered
    envList = [['-' for i in range(N_states)] for j in range(N_states)]
    envList[N_states // 2][N_states // 2] = 'T'  # the treasure at the center of the map
    if S == "terminal":
        interaction = 'Episode %s: total_steps = %s' % (episode + 1, stepCnt)
        print('\r{}'.format(interaction), end='')
        time.sleep(2)
        print('\r \n', end='')
    else:
        envList[S // N_states][S % N_states] = 'o'  # the explorer
        for i in range(N_states):
            print(envList[i])
        print("\n")
        time.sleep(interval)
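Printing a Python list shows brackets and commas; if you want a cleaner grid, a join-based renderer works as a drop-in replacement for the print loop above (a small optional variant, not part of the original code):

def renderGrid(envList):
    # join each row into a plain "o - - ..." line instead of printing Python lists
    print('\n'.join(' '.join(row) for row in envList))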
Updating the Q-table
The parameter update follows the standard Q-learning pseudocode exactly.
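For each transition (S, A, R, S'), the temporal-difference update is

Q(S, A) ← Q(S, A) + alpha * (R + gamma * max_a Q(S', a) - Q(S, A))

where the term in parentheses is the TD error; when S' is terminal the target is just R, since there is no future value to bootstrap.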
#most important part: the update of the Q-table
def rl():
    x = []
    cnts = []
    QTable = build_table(N_states, ACTION)
    for episode in range(episodes):
        x.append(episode)
        stepCnt = 0
        S = random.randint(0, N_states*N_states - 1)  # the random initial position
        terminated = False  # whether this episode has ended
        updateEnvironment(S, episode, stepCnt)
        while not terminated:
            a = chooseActions(S, QTable)  # pick the next action from the state and the Q-table
            S_, R = feedback(S, a)  # play one step, getting the reward and the next state
            QPredict = QTable.loc[S, a]  # predicted value of the state-action pair
            if S_ != "terminal":
                QTarget = R + gamma * QTable.iloc[S_, :].max()  # bootstrapped target value
            else:
                QTarget = R  # terminal: nothing left to bootstrap
                terminated = True  # end this episode
            # update the Q-table with the temporal-difference error
            QTable.loc[S, a] += alpha * (QTarget - QPredict)
            S = S_
            stepCnt += 1
            updateEnvironment(S, episode, stepCnt)
        cnts.append(stepCnt)
    plt.plot(x, cnts, color = "grey", linestyle = '--', marker = '^')
    plt.xlabel("episode")
    plt.ylabel("Total steps")
    plt.show()  # display the learning curve (steps per episode)
    return QTable
Main function
#begin to play
if __name__ == "__main__":
    q_table = rl()
    print('\r\nQ-table:\n')
    print(q_table)
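Once training finishes, the greedy policy can be read straight off the table; a minimal sketch (idxmax along axis=1 returns, for each state, the name of its best action; untouched all-zero rows simply default to the first column):

policy = q_table.idxmax(axis=1)  # best action name per state
print(policy.head(9))  # greedy actions for the first row of the map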