The SARSA algorithm:
The SARSA algorithm follows the interaction sequence: it estimates value using the action that is actually taken at the next step.
Another TD method, Q-Learning, does not follow the interaction sequence; at the next time step it instead uses the action that maximizes the estimated value.
These two algorithms represent two ways of evaluating a policy: On-Policy and Off-Policy.
On-Policy updates the value function strictly from the interaction sequence, assuming during the computation that the value can be estimated directly from the sampled sequence.
Off-Policy does not follow the interaction sequence completely; it replaces part of the original sequence with a piece that could come from another policy's interaction sequence.
In terms of the underlying idea, Q-Learning is the more involved of the two: by plugging in the optimal value of that sub-part, it looks more like the update rule of Value Iteration, hoping that every update makes use of the best results accumulated in earlier iterations.
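To make the difference concrete, the two TD targets can be written side by side. The sketch below is a minimal illustration only; the names q, reward, gamma, next_state and next_act are hypothetical and do not come from the code in this section.

import numpy as np

def sarsa_target(q, reward, gamma, next_state, next_act):
    # On-Policy: bootstrap with the action the behavior policy actually takes next.
    return reward + gamma * q[next_state, next_act]

def q_learning_target(q, reward, gamma, next_state):
    # Off-Policy: bootstrap with the greedy (maximum) value at the next state.
    return reward + gamma * np.max(q[next_state, :])

SARSA's target therefore depends on the exploratory policy being executed, while Q-Learning's target always assumes greedy behavior from the next state onward.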
The complete Q-Learning code is as follows:
import numpy as np

class QLearning(object):
    def __init__(self, epsilon=0.0):
        # Probability of taking a random exploratory action in epsilon-greedy play.
        self.epsilon = epsilon

    def q_learn_eval(self, agent, env):
        # Run one episode and update the action-value estimates with the
        # Q-Learning target: reward + gamma * max_a Q(s, a).
        state = env.reset()
        prev_state = -1
        prev_act = -1
        while True:
            act = agent.play(state, self.epsilon)
            next_state, reward, terminate, _ = env.step(act)
            if prev_act != -1:
                if terminate:
                    return_val = reward
                else:
                    return_val = reward + agent.gamma * np.max(agent.value_q[state, :])
                # Incremental averaging of the TD target for (prev_state, prev_act).
                agent.value_n[prev_state][prev_act] += 1
                agent.value_q[prev_state][prev_act] += (
                    (return_val - agent.value_q[prev_state][prev_act]) /
                    agent.value_n[prev_state][prev_act]
                )
            prev_act = act
            prev_state = state
            state = next_state
            if terminate:
                break
    def policy_improve(self, agent):
        # Greedy policy improvement: for every state pick the action with the
        # largest estimated value; report whether the policy changed.
        new_policy = np.zeros_like(agent.pi)
        for i in range(1, agent.s_len):
            new_policy[i] = np.argmax(agent.value_q[i, :])
        if np.all(np.equal(new_policy, agent.pi)):
            return False
        else:
            agent.pi = new_policy
            return True
    def q_learning(self, agent, env):
        # Generalized policy iteration: many evaluation episodes per round,
        # followed by one greedy policy improvement.
        for i in range(10):
            for j in range(3000):
                self.q_learn_eval(agent, env)
            self.policy_improve(agent)
# SnakeEnv, ModelFreeAgent, timer and eval_game are the helpers defined in the
# earlier sections of this chapter.
def q_learning_demo():
    env = SnakeEnv(10, [3, 6])
    np.random.seed(101)
    agent3 = ModelFreeAgent(env)
    ql = QLearning(0.5)
    with timer('Timer Q Learning Iter'):
        ql.q_learning(agent3, env)
    print('return_pi={}'.format(eval_game(env, agent3)))
    print(agent3.pi)

q_learning_demo()
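For comparison, an On-Policy SARSA evaluation step would differ from q_learn_eval only in the bootstrap term: instead of taking the maximum over actions, it uses the value of the action the policy actually selects. The following is a minimal sketch assuming the same agent interface as above (play, gamma, value_q, value_n); it is illustrative only, not the SARSA implementation referenced at the start of this section.

def sarsa_eval(agent, env, epsilon=0.5):
    # On-Policy evaluation sketch; assumes the same agent/env interface as above.
    state = env.reset()
    prev_state = -1
    prev_act = -1
    while True:
        act = agent.play(state, epsilon)
        next_state, reward, terminate, _ = env.step(act)
        if prev_act != -1:
            if terminate:
                return_val = reward
            else:
                # On-Policy bootstrap: value of the action actually chosen.
                return_val = reward + agent.gamma * agent.value_q[state][act]
            agent.value_n[prev_state][prev_act] += 1
            agent.value_q[prev_state][prev_act] += (
                (return_val - agent.value_q[prev_state][prev_act]) /
                agent.value_n[prev_state][prev_act]
            )
        prev_act = act
        prev_state = state
        state = next_state
        if terminate:
            break

Because the bootstrap uses the epsilon-greedy action rather than the greedy one, the values SARSA learns reflect the exploratory behavior of the policy itself, which is exactly the On-Policy property described above.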