Sarsa learns more slowly but finds a safer path; Q-learning learns faster, but because its greedy target steers it along the cliff edge, ε-greedy exploration makes it fall off the cliff more often during training.
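The two algorithms differ only in the TD target: Sarsa is on-policy and bootstraps from the action its ε-greedy policy will actually take next, while Q-learning is off-policy and bootstraps from the greedy action. A minimal sketch of the two update rules (the helper function names here are mine, not part of the code below):

import numpy as np

def sarsa_update(q_table, s, a, r, s_next, a_next, alpha, gamma):
    # On-policy target: the action actually chosen by the behavior policy in s_next
    q_table[s, a] += alpha * (r + gamma * q_table[s_next, a_next] - q_table[s, a])

def q_learning_update(q_table, s, a, r, s_next, alpha, gamma):
    # Off-policy target: the greedy (max) action in s_next, regardless of what is actually taken
    q_table[s, a] += alpha * (r + gamma * np.max(q_table[s_next]) - q_table[s, a])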
Here is a simple implementation of the cliff-walking task:
Sarsa:
import gym
import numpy as np
import matplotlib.pyplot as plt

def epsilon_greedy(state, epsilon):
    # Explore with probability epsilon
    if np.random.uniform(0, 1) < epsilon:
        return env.action_space.sample()
    # Otherwise exploit the current Q-table
    else:
        return np.argmax(q_table[state, :])

def Sarsa():
    for i in range(num_episodes):
        state = env.reset()
        # Linearly decay epsilon from 0.9 to 0.1 over the training run
        epsilon = np.linspace(0.9, 0.1, num=num_episodes)[i]
        action = epsilon_greedy(state, epsilon)
        r = 0
        while True:
            next_state, reward, done, _ = env.step(action)
            # On-policy: the next action is chosen by the same epsilon-greedy policy
            next_action = epsilon_greedy(next_state, epsilon)
            q_table[state, action] += alpha * (reward + gamma * q_table[next_state, next_action] - q_table[state, action])
            # Move to the next state/action pair
            state = next_state
            action = next_action
            r += reward
            if done:
                print("Episode %s finished with total reward %s" % (i, r))
                break
        rewards.append(r)

def printBestRoute():
    # Roll out the greedy policy from the start state and print the route taken
    best_route = []
    state = env.reset()
    best_route.append(state)
    while True:
        action = np.argmax(q_table[state, :])
        print(action)
        next_state, _, done, _ = env.step(action)
        state = next_state
        best_route.append(state)
        if done:
            break
    print(best_route)

def drawRewards():
    plt.rcParams['axes.unicode_minus'] = False
    plt.figure(figsize=(16, 7))
    plt.plot(rewards, 'b-', label='Sarsa')
    # plt.plot(rewards, 'r-', label='Q-learning')
    plt.legend(loc='best', fontsize=15)  # legend position
    plt.tick_params(labelsize=15)  # tick label size
    plt.xlabel('Episode', fontsize=15)
    plt.ylabel('Cumulative reward', fontsize=15)
    plt.title('Sarsa', fontsize=20)
    plt.show()

if __name__ == '__main__':
    # Assumes the old gym API (< 0.26): reset() returns the state and step() returns a 4-tuple
    env = gym.make('CliffWalking-v0')
    rewards = []
    alpha = 0.8
    gamma = 0.95
    q_table = np.zeros([env.observation_space.n, env.action_space.n])
    num_episodes = 600
    Sarsa()
    # printBestRoute()
    drawRewards()
    print(rewards)
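printBestRoute only prints the visited states and actions; to compare the two learned paths numerically, one can also sum the rewards of a single greedy rollout. A minimal sketch under the same old-gym API assumption (greedy_return is a hypothetical helper, not part of the original code):

import numpy as np

def greedy_return(env, q_table, max_steps=100):
    # Follow the greedy policy from the start state and return its total reward,
    # capping the rollout in case the greedy policy has not converged yet
    state = env.reset()
    total = 0
    for _ in range(max_steps):
        action = np.argmax(q_table[state, :])
        state, reward, done, _ = env.step(action)
        total += reward
        if done:
            break
    return total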
Q-learning:
import gym
import numpy as np
import matplotlib.pyplot as plt

def epsilon_greedy(state, epsilon):
    # Explore with probability epsilon, otherwise exploit the current Q-table
    if np.random.uniform(0, 1) < epsilon:
        return env.action_space.sample()
    else:
        return np.argmax(q_table[state, :])

def q_learning():
    for i in range(num_episodes):
        state = env.reset()
        # Linearly decay epsilon from 0.9 to 0.1 over the training run
        epsilon = np.linspace(0.9, 0.1, num_episodes)[i]
        episode_reward = 0
        while True:
            action = epsilon_greedy(state, epsilon)
            next_state, reward, done, _ = env.step(action)
            print("Episode %s" % i)
            print("Moved from state %s to state %s" % (state, next_state))
            print("Chose action %s in state %s" % (action, state))
            # Off-policy: bootstrap from the greedy (max) action in the next state
            q_table[state, action] += alpha * (reward + gamma * np.max(q_table[next_state]) - q_table[state, action])
            state = next_state
            episode_reward += reward
            if done:
                print("Episode %s finished with total reward %s" % (i, episode_reward))
                break
        rewards.append(episode_reward)

def printBestRoute():
    # Roll out the greedy policy from the start state and print the route taken
    best_route = []
    state = env.reset()
    best_route.append(state)
    while True:
        action = np.argmax(q_table[state, :])
        print(action)
        next_state, _, done, _ = env.step(action)
        state = next_state
        best_route.append(state)
        if done:
            break
    print(best_route)

def drawRewards():
    plt.rcParams['axes.unicode_minus'] = False
    plt.figure(figsize=(16, 7))
    # plt.plot(rewards, 'b-', label='Sarsa')
    plt.plot(rewards, 'r-', label='Q-learning')
    plt.legend(loc='best', fontsize=15)  # legend position
    plt.tick_params(labelsize=15)  # tick label size
    plt.xlabel('Episode', fontsize=15)
    plt.ylabel('Cumulative reward', fontsize=15)
    plt.title('Q-learning', fontsize=20)
    plt.show()

if __name__ == '__main__':
    # Assumes the old gym API (< 0.26): render_mode='human' requires newer gym versions,
    # whose reset()/step() signatures differ from the 4-tuple API used here
    env = gym.make('CliffWalking-v0')
    np.random.seed(0)
    rewards = []
    alpha = 0.8
    gamma = 0.95
    q_table = np.zeros([env.observation_space.n, env.action_space.n])
    num_episodes = 600
    q_learning()
    printBestRoute()
    drawRewards()
    print(rewards)
Judging from the comparison, Q-learning learns faster than Sarsa, which matches the point made at the start: Q-learning is quicker, while Sarsa trades speed for a safer path.
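The commented-out plt.plot lines in the two drawRewards functions suggest the comparison was made by drawing both reward curves on one figure. A minimal sketch of such a helper (draw_comparison and its argument names are illustrative, not from the original code):

import matplotlib.pyplot as plt

def draw_comparison(sarsa_rewards, qlearning_rewards):
    # Overlay the per-episode cumulative rewards of both algorithms on one figure
    plt.figure(figsize=(16, 7))
    plt.plot(sarsa_rewards, 'b-', label='Sarsa')
    plt.plot(qlearning_rewards, 'r-', label='Q-learning')
    plt.legend(loc='best', fontsize=15)
    plt.xlabel('Episode', fontsize=15)
    plt.ylabel('Cumulative reward', fontsize=15)
    plt.title('Sarsa vs Q-learning on CliffWalking-v0', fontsize=20)
    plt.show()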