在本节代码中,莫烦大神用 Q-Learning 算法实现了一个探索者走迷宫的游戏,核心代码包括 run_this.py 和 RL_brain.py。其中,run_this.py 对应迭代更新部分。
from maze_env import Maze
from RL_brain import QLearningTable
def update():
    """Run the training loop: 100 episodes of interaction, then shut down.

    Relies on the module-level ``env`` and ``RL`` objects created in the
    ``__main__`` section. On every step the environment is redrawn, the
    agent picks an action for the stringified observation, the environment
    is advanced, and the agent learns from the resulting transition. An
    episode ends as soon as ``env.step`` reports a truthy ``done`` flag.
    After the last episode ``'game over'`` is printed and the environment
    is destroyed.
    """
    for _ in range(100):  # play 100 rounds of the game
        # Game start: reset the environment and take the initial observation.
        state = env.reset()
        finished = False
        while not finished:
            # Refresh the environment display.
            env.render()
            # Observations are stringified before being handed to the agent.
            chosen = RL.choose_action(str(state))
            # Apply the action; receive the next observation, the reward and
            # the end-of-episode flag.
            next_state, reward, finished = env.step(chosen)
            # Let the agent learn from this transition.
            RL.learn(str(state), chosen, reward, str(next_state))
            # The next observation becomes the current one.
            state = next_state
    # End of game.
    print('game over')
    env.destroy()
if __name__ == "__main__":
    # `env` and `RL` are deliberately module-level names: update() reads
    # them as globals.
    env = Maze()
    # One integer id per action the environment exposes.
    action_ids = list(range(env.n_actions))
    RL = QLearningTable(actions=action_ids)
    # Register update() to be called via env.after with a delay argument of
    # 100 (presumably milliseconds on a Tk-style event loop -- confirm
    # against maze_env), then hand control to the environment's main loop.
    env.after(100, update)
    env.mainloop()
RL_brain.py 实现具体的 Q-Learning 算法。
import numpy as np
import pandas as pd
class QLearningTable:
    """Tabular Q-learning agent whose Q-table is a pandas DataFrame.

    The DataFrame has one column per action and one row per state label.
    Rows are created lazily, filled with zeros, the first time a state is
    seen by ``choose_action`` or as the next state in ``learn``.
    """

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        """Store the hyper-parameters and create an empty Q-table.

        Args:
            actions: List of action labels; they become the table's columns.
            learning_rate: Step size applied to each Q-value update.
            reward_decay: Discount factor applied to the next state's best
                Q-value.
            e_greedy: Probability of picking the greedy (best-valued) action
                rather than a uniformly random one.
        """
        self.actions = actions  # a list
        self.lr = learning_rate  # learning rate
        self.gamma = reward_decay  # reward decay (discount factor)
        self.epsilon = e_greedy  # greediness
        # Initialise the Q-table; at this point the row index is empty.
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def choose_action(self, observation):
        """Pick an action for ``observation`` with an epsilon-greedy policy.

        With probability ``self.epsilon`` the action with the highest Q-value
        is returned (ties are broken uniformly at random); otherwise a
        uniformly random action is returned. The state's row is created
        first if it does not exist yet.

        Args:
            observation: Row label of the current state in the Q-table.

        Returns:
            One of the labels in ``self.actions``.
        """
        self.check_state_exist(observation)
        # action selection
        if np.random.uniform() < self.epsilon:
            # choose best action
            # Q-values of every action in this state
            state_action = self.q_table.loc[observation, :]
            # some actions may have the same value; choose randomly among them
            best_actions = state_action[state_action == np.max(state_action)].index
            action = np.random.choice(best_actions)
        else:
            # choose random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_):
        """Apply one Q-learning update for the transition (s, a, r, s_).

        Args:
            s: Row label of the state the action was taken in. Its row must
                already exist in the table; only ``s_`` is added on demand.
            a: Action that was taken (a column label).
            r: Reward received.
            s_: Row label of the resulting state. The literal string
                ``'terminal'`` marks the end of an episode, in which case
                no bootstrapped value is added to the target.
        """
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]  # current estimate
        if s_ != 'terminal':
            # next state is not terminal: reward plus discounted best value
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            q_target = r  # next state is terminal
        # Move the estimate toward the target by learning-rate * error.
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update

    def check_state_exist(self, state):
        """Add an all-zero row for ``state`` if the table has none yet.

        Args:
            state: Row label of the state (callers in this file pass a
                string).
        """
        if state not in self.q_table.index:
            # Append the new state by setting-with-enlargement.
            # DataFrame.append was removed in pandas 2.0, so it cannot be
            # used here. Float zeros keep the columns float64.
            self.q_table.loc[state] = [0.0] * len(self.actions)