The pseudocode for off-policy n-step Q(σ):
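As a reminder of what the listing keeps track of (this is the TD-error form of n-step Q(σ); the index bookkeeping below is one consistent reading of it): at each non-terminal step the agent stores

$$\delta_t = R_{t+1} + \gamma\,\sigma_{t+1}\,Q(S_{t+1},A_{t+1}) + \gamma\,(1-\sigma_{t+1})\sum_a \pi(a\mid S_{t+1})\,Q(S_{t+1},a) - Q(S_t,A_t),$$

together with $\pi_{t+1}=\pi(A_{t+1}\mid S_{t+1})$ and the importance ratio $\rho_{t+1}=\pi(A_{t+1}\mid S_{t+1})/\mu(A_{t+1}\mid S_{t+1})$, where $\mu$ is the behaviour policy. Once $\tau = t-n+1 \ge 0$, the pair visited $n$ steps earlier is updated as

$$Q(S_\tau,A_\tau) \leftarrow Q(S_\tau,A_\tau) + \alpha\,\rho\,\bigl[G - Q(S_\tau,A_\tau)\bigr], \qquad G = Q(S_\tau,A_\tau) + \sum_{k=\tau}^{\min(\tau+n-1,\,T-1)} E_k\,\delta_k,$$

with $E_k = \prod_{i=\tau+1}^{k}\gamma\bigl[(1-\sigma_i)\pi_i+\sigma_i\bigr]$ and $\rho = \prod_{i=\tau+1}^{\min(\tau+n,\,T-1)}\bigl(1-\sigma_i+\sigma_i\rho_i\bigr)$.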
The code for maze_env can be found here.
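The listing only relies on a small part of that environment's interface: a Maze class exposing n_actions, reset(), step(action) returning (next_state, reward, done) with the string 'terminal' as the next state at the end of an episode, plus render() and destroy(). A minimal stand-in with the same surface (the 4x4 grid, the pit/goal positions and the -1/0/+1 rewards are assumptions for illustration, not the original environment) looks like this:

```python
class Maze(object):
    """Minimal stand-in for the maze environment (assumed layout:
    a 4x4 grid with one pit and one goal; rewards -1 / +1 / 0)."""

    def __init__(self, size=4):
        self.size = size
        self.n_actions = 4                      # up, down, left, right
        self.goal, self.pit = (3, 3), (2, 2)    # assumed positions

    def reset(self):
        self.pos = (0, 0)
        return self.pos                         # the agent stringifies this state

    def step(self, action):
        moves = [(-1, 0), (1, 0), (0, -1), (0, 1)]
        r, c = self.pos
        dr, dc = moves[action]
        self.pos = (min(max(r + dr, 0), self.size - 1),
                    min(max(c + dc, 0), self.size - 1))
        if self.pos == self.goal:
            return 'terminal', 1, True
        if self.pos == self.pit:
            return 'terminal', -1, True
        return self.pos, 0, False

    def render(self):
        pass                                    # the real environment draws a tkinter canvas

    def destroy(self):
        pass
```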
import numpy as np
import pandas as pd
from maze_env import Maze
import random


class QSigma(object):
    # Off-policy n-step Q(σ)
    def __init__(self, action_space):
        self.nA = action_space
        self.actions = list(range(action_space))
        # Q-table: one row per visited state, one column per action
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, s):
        # Lazily add a row of zeros for a state seen for the first time
        if s not in self.q_table.index:
            new_row = pd.Series([0.0] * len(self.actions),
                                index=self.q_table.columns,
                                name=s)
            self.q_table = pd.concat([self.q_table, new_row.to_frame().T])
    def target_policy(self, s):
        # Target policy: epsilon-greedy with a small epsilon (close to greedy)
        self.check_state_exist(s)
        A = self.target_policy_probs(s)
        return np.random.choice(range(self.nA), p=A)

    def target_policy_probs(self, s, epsilon=.1):
        A = np.ones(self.nA, dtype=float) * epsilon / self.nA
        best_action = np.argmax(self.q_table.loc[s, :])
        A[best_action] += (1.0 - epsilon)
        return A

    def behaviour_policy(self, s):
        # Behaviour policy: epsilon-greedy with a larger epsilon (more exploratory)
        self.check_state_exist(s)
        A = self.behaviour_policy_probs(s)
        return np.random.choice(range(self.nA), p=A)

    def behaviour_policy_probs(self, s, epsilon=.3):
        A = np.ones(self.nA, dtype=float) * epsilon / self.nA
        best_action = np.argmax(self.q_table.loc[s, :])
        A[best_action] += (1.0 - epsilon)
        return A

if __name__ == '__main__':
    env = Maze()
    action_space = env.n_actions
    RL = QSigma(action_space)
    n = 3          # number of steps in the n-step backup
    gamma = 0.9    # discount factor
    alpha = 0.01   # step size
    for episode in range(100):
        buffer_s = []
        buffer_a = []
        buffer_r = []
        buffer_Q = []
        TD_error = []
        pi = []
        rou = []
        buffer_sigma = []
        state = env.reset()
        action = RL.behaviour_policy(str(state))
        buffer_s.append(str(state))
        buffer_a.append(action)
        buffer_Q.append(RL.q_table.loc[str(state), action])  # Store Q(S0, A0) as Q0
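        # Index convention for the stored lists:
        #   buffer_s[t], buffer_a[t], buffer_Q[t] hold S_t, A_t and Q(S_t, A_t),
        #   TD_error[t] holds delta_t, while buffer_sigma[t], pi[t] and rou[t]
        #   hold sigma_{t+1}, pi_{t+1} and rho_{t+1} (first filled at t = 0,
        #   for the transition into S_1).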
        T = 10000  # stands in for T = infinity until the terminal state is reached
        t = 0
        while True:
            if t < T:
                env.render()
                state_, reward, done = env.step(action)
                buffer_s.append(str(state_))
                # buffer_r.append(reward)
                if state_ == 'terminal':
                    T = t + 1
                    TD_error.append(reward - buffer_Q[t])
                else:
                    RL.check_state_exist(str(state_))
                    action_ = RL.behaviour_policy(str(state_))
                    buffer_a.append(action_)
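                    # sigma_{t+1} is drawn uniformly from {0, 1}: sigma = 1 gives a
                    # sampled (Sarsa-style) backup, sigma = 0 an expected
                    # (tree-backup-style) backup.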
                    sigma = random.choice([0, 1])
                    buffer_sigma.append(sigma)
                    buffer_Q.append(RL.q_table.loc[str(state_), action_])  # Q_{t+1}
                    # Expected value of S_{t+1} under the target policy:
                    # sum_a pi(a | S_{t+1}) * Q(S_{t+1}, a)
                    target_probs = RL.target_policy_probs(str(state_))
                    temp = 0
                    for a in range(RL.nA):
                        temp += target_probs[a] * RL.q_table.loc[str(state_), a]
                    # delta_t = R_{t+1} + gamma * sigma_{t+1} * Q_{t+1}
                    #         + gamma * (1 - sigma_{t+1}) * sum_a pi(a|S_{t+1}) Q(S_{t+1}, a) - Q_t
                    TD_error.append(reward + gamma * sigma * buffer_Q[t + 1] +
                                    gamma * (1 - sigma) * temp - buffer_Q[t])
                    pi.append(target_probs[action_])                              # pi_{t+1}
                    rou.append(target_probs[action_] /
                               RL.behaviour_policy_probs(str(state_))[action_])   # rho_{t+1}
                    action = action_
            tao = t - n + 1
            if tao >= 0:
                # n-step update of Q(S_tao, A_tao) from the stored TD errors
                rho = 1
                E = 1
                G = buffer_Q[tao]
                for k in range(tao, min(tao + n, T)):
                    G += E * TD_error[k]
                    if k + 1 < T:  # sigma_{k+1}, pi_{k+1}, rho_{k+1} exist only for non-terminal steps
                        # E accumulates gamma * [(1 - sigma_{k+1}) * pi_{k+1} + sigma_{k+1}]
                        E = gamma * E * ((1 - buffer_sigma[k]) * pi[k] + buffer_sigma[k])
                        # rho accumulates the importance correction (1 - sigma_{k+1} + sigma_{k+1} * rho_{k+1})
                        rho = rho * (1 - buffer_sigma[k] + buffer_sigma[k] * rou[k])
                RL.q_table.loc[buffer_s[tao], buffer_a[tao]] += \
                    alpha * rho * (G - RL.q_table.loc[buffer_s[tao], buffer_a[tao]])
            if tao == T - 1:
                break
            t += 1
    print('game over')
    env.destroy()
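Two notes on the listing: σ is redrawn uniformly from {0, 1} at every step, so the algorithm mixes Sarsa-style sampled backups (σ = 1) with tree-backup-style expected backups (σ = 0); and the behaviour policy explores more (ε = 0.3) than the ε = 0.1 target policy, which is what makes the learning off-policy and the importance ratios ρ necessary.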