算法伪代码（优先级扫描，Prioritized Sweeping）：
使用迷宫环境 maze_env（其代码见原文中的链接）：
import heapq
import queue
from collections import defaultdict

import numpy as np
import pandas as pd

from maze_env import Maze
class Q(object):
    """Tabular action-value table plus an epsilon-greedy policy over it.

    Rows are states keyed by ``str(state)``: either ``'terminal'`` or the
    string form of a four-number coordinate list ``[x, y, x + 30, y + 30]``.
    Columns are the action indices ``0 .. action_space - 1``.
    """

    # x and y origins of every grid cell enumerated by init_Q.
    _CELL_ORIGINS = range(5, 165, 40)
    # Offset added to a cell origin to get the far corner of its coordinates.
    _CELL_EXTENT = 30.0
    # Cell origins whose states all collapse into the single 'terminal' row.
    _TERMINAL_ORIGINS = frozenset({(45, 85), (85, 45), (85, 85)})

    def __init__(self, action_space):
        self.nA = action_space
        self.actions = list(range(action_space))
        # Explicit float dtype: appending rows to a column-only frame would
        # otherwise leave every column as dtype=object.
        self.q_table = pd.DataFrame(columns=self.actions, dtype=float)
        self.init_Q()

    def init_Q(self):
        """Add a zero-valued row for every grid state not yet in the table.

        States are enumerated x-major, then y. All terminal cells share one
        ``'terminal'`` row, placed where the first terminal cell is met.
        """
        states = []
        for x in self._CELL_ORIGINS:
            for y in self._CELL_ORIGINS:
                if (x, y) in self._TERMINAL_ORIGINS:
                    states.append('terminal')
                else:
                    coords = [x + 0.0, y + 0.0,
                              x + self._CELL_EXTENT, y + self._CELL_EXTENT]
                    states.append(str(coords))
        # dict.fromkeys de-duplicates while keeping first-seen order.
        new_states = [s for s in dict.fromkeys(states)
                      if s not in self.q_table.index]
        if new_states:
            # One reindex instead of per-row DataFrame.append, which was
            # removed in pandas 2.0 and copied the whole frame on every call.
            self.q_table = self.q_table.reindex(
                list(self.q_table.index) + new_states, fill_value=0.0)

    def check_state_exist(self, s):
        """Add an all-zero row for state ``s`` if the table lacks one."""
        if s not in self.q_table.index:
            self.q_table.loc[s] = [0.0] * len(self.actions)

    def target_policy(self, s):
        """Sample an action for state ``s`` from ``target_policy_probs``.

        Despite the name, the policy is epsilon-greedy (epsilon defaults to
        0.3 in ``target_policy_probs``), not purely greedy.
        """
        A = self.target_policy_probs(s)
        return np.random.choice(range(self.nA), p=A)

    def target_policy_probs(self, s, epsilon=.3):
        """Return epsilon-greedy action probabilities for state ``s``.

        Every action gets ``epsilon / nA``. The remaining ``1 - epsilon`` is
        split evenly among all actions tied for the maximum Q-value. Splitting
        avoids always favouring the lowest action index when values tie, for
        example on a freshly zeroed table. Unknown states are first added
        with zero values.
        """
        self.check_state_exist(s)
        A = np.ones(self.nA, dtype=float) * epsilon / self.nA
        # Column order equals self.actions, so positions are action indices.
        q_values = self.q_table.loc[s, :].to_numpy(dtype=float)
        best_actions = np.flatnonzero(q_values == q_values.max())
        A[best_actions] += (1.0 - epsilon) / len(best_actions)
        return A
class Model:
    """Deterministic one-step environment model used for planning.

    ``self.model`` maps a ``(state, action)`` tuple to the most recently
    observed ``[reward, next_state]`` list. Storing the same pair again
    overwrites the earlier outcome.
    """

    def __init__(self):
        self.model = {}

    def store(self, s, a, r, s_):
        """Record that action ``a`` taken in state ``s`` yielded ``r`` and ``s_``."""
        key = (s, a)
        outcome = [r, s_]
        self.model[key] = outcome
if __name__ == '__main__':
env = Maze()
action_space = env.n_actions
RL = Q(action_space)
model = Model()
PQueue = queue.Queue()
gamma = 0.9
alpha = 0.01
theta = 0.5
for episode in range(100):
state = env.reset()
while True:
env.render()
action = RL.target_policy(str(state))
state_, reward, done = env.step(action)
model.store(str(state), action, reward, str(state_))
P = abs(reward + gamma * np.max(RL.q_table.loc[str(state_), :])
- RL.q_table.loc[str(state), action])
if P > theta:
PQueue.put([str(state), action])
while not PQueue.empty():
S_A = PQueue.get()
S = S_A[0]
A = S_A[1]
R, S_ = model.model[S, A]
G = R + gamma * np.max(RL.q_table.loc[S_, :])
RL.q_table.loc[S, A] += alpha * (G - RL.q_table.loc[S, A])
values = model.model.values()
# S = ‘terminal'时 已经结束 这里不需要考虑
if [0, S] in model.model.values():
for keys in range(0, len(model.model)):
if list(model.model.values())[keys] == [0, S]:
_s_a = list(model.model.keys())[keys]
_s_a = list(_s_a)
_s = _s_a[0]
_a = _s_a[1]
_r = 0
P = abs(_r + gamma * np.max(RL.q_table.loc[str(S), :])
- RL.q_table.loc[_s, _a])
if P > theta:
PQueue.put([_s, _a])
if done:
break
state = state_
print('game over')
env.destroy()