1. Q-learning 和 SARSA 的区别
(1) Q-learning —— off-policy(离策略):更新目标使用下一状态的最大 Q 值,而不是实际执行的下一动作
(2) SARSA —— on-policy(在策略):更新目标使用实际将要执行的下一动作的 Q 值
2. 程序上的差异
学习(learn)方式不同:SARSA 的 learn 需要额外传入下一步动作 a_
class SarsaTable(RL):
    """Tabular SARSA agent (on-policy TD control) built on the shared RL base."""

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        # All bookkeeping (q_table, lr, gamma, epsilon) lives in the RL base class.
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        """Apply one on-policy TD(0) update for the transition (s, a, r, s_, a_).

        Unlike Q-learning, the bootstrap term uses the action a_ that will
        actually be taken in s_, not the greedy maximum.
        """
        self.check_state_exist(s_)
        current = self.q_table.loc[s, a]
        # A terminal successor contributes no future value to the target.
        target = r if s_ == 'terminal' else r + self.gamma * self.q_table.loc[s_, a_]
        self.q_table.loc[s, a] = current + self.lr * (target - current)
训练循环(update)的写法不同:每一步必须先选出下一动作 a_,再调用 learn 进行更新
def update():
    """Train the global SARSA agent on the global maze env for 100 episodes."""
    for episode in range(100):
        # Fresh episode: reset the environment to its start state.
        observation = env.reset()
        # SARSA is on-policy: the first action is chosen before stepping.
        action = RL.choose_action(str(observation))
        done = False
        while not done:
            # Redraw the environment each step.
            env.render()
            # Take the action; receive reward, next observation, and episode flag.
            observation_, reward, done = env.step(action)
            # Pick the next action now — it is part of the SARSA update tuple.
            action_ = RL.choose_action(str(observation_))
            # Move Q(s, a) toward r + gamma * Q(s', a').
            RL.learn(str(observation), action, reward, str(observation_), action_)
            # Advance: the chosen next action WILL be executed next iteration.
            observation, action = observation_, action_
    # All episodes finished — tear down the GUI.
    print('game over')
    env.destroy()
if __name__ == "__main__":
    # Build the maze environment and a SARSA agent over its discrete action set.
    env = Maze()
    RL = SarsaTable(actions=list(range(env.n_actions)))
    # Schedule training 100 ms after the Tk window is up, then enter the GUI loop.
    env.after(100, update)
    env.mainloop()