CliffWalking is a classic introductory problem in reinforcement learning. It describes a grid-world environment in which an agent must travel from a start cell S to a goal cell G, and in which a cliff sends the agent back to the start whenever it steps in. The agent's objective is to learn an optimal policy that reaches the goal in as few steps as possible.
Environment
The CliffWalking environment is usually represented as a grid world containing the following elements:
- Start (S): the cell where the agent begins.
- Goal (G): the cell the agent wants to reach.
- Cliff (C): cells that send the agent back to the start when stepped into, with a large penalty of -100.
- Ordinary cells: every other move gives a reward of -1.
Number of states
48 (the 4 × 12 grid flattened into a single index)
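The 48 states are the cells of the 4 × 12 grid, flattened row by row (state = row * 12 + col). The helper below is a small illustrative sketch; to_state is a hypothetical name, not part of the environment's API.

# Hypothetical helper: map a (row, col) cell of the 4 x 12 grid to the flattened
# state index that CliffWalking-v0 reports as its observation.
def to_state(row, col, n_cols=12):
    return row * n_cols + col

print(to_state(3, 0))                          # 36 -> start S (bottom-left corner)
print(to_state(3, 11))                         # 47 -> goal G (bottom-right corner)
print([to_state(3, c) for c in range(1, 11)])  # 37..46 -> the cliff cells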
The agent can take four actions; the action indices are:
- Up: 0
- Right: 1
- Down: 2
- Left: 3
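As a quick check of both the action indexing and the cliff behaviour described above, the short sketch below takes a single step to the right from the start cell. It assumes the same gym version with the five-value step API used in the training code further down.

import gym

env = gym.make('CliffWalking-v0')
state, _ = env.reset()
print(state)  # 36: the start cell S at row 3, col 0

# Action 1 = right moves straight into the cliff: the agent is sent back to the
# start with a -100 penalty, and the episode does not terminate.
next_state, reward, terminated, truncated, _ = env.step(1)
print(next_state, reward, terminated)  # 36 -100 False
env.close()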
Q-Learning algorithm
Q-Learning is a commonly used reinforcement learning algorithm that works well on the CliffWalking problem. It maintains a Q table that stores the expected return of every state-action pair; the agent selects actions based on the Q table and keeps updating it to improve the policy.
Algorithm flow
- Initialize the Q table: set the Q value of every state-action pair to 0.
- Choose an action: pick an action for the current state with an ε-greedy policy.
- Execute the action: interact with the environment and observe the reward and the next state.
- Update the Q table: apply the Bellman-based update shown below.
- Repeat steps 2-4 until a termination condition is met.
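In step 4, with learning rate \alpha, discount factor \gamma, observed reward r and next state s', the tabular Q-learning update is

Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]

When s' is terminal, the bracketed target reduces to r alone, which is exactly the special case handled in the code below.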
import gym
import numpy as np
import matplotlib.pyplot as plt
# Create the CliffWalking environment
env = gym.make('CliffWalking-v0')
# Initialize Q-table with zeros
Q = np.zeros((env.observation_space.n, env.action_space.n))
# Set hyperparameters
alpha = 0.1 # Learning rate
gamma = 0.99 # Discount factor
epsilon = 0.1 # Exploration rate
num_episodes = 1000 # Number of episodes
# Function for choosing an action using an epsilon-greedy policy
def epsilon_greedy(state, epsilon):
    if np.random.rand() < epsilon:
        # Explore: pick a random action
        return np.random.choice(env.action_space.n)
    else:
        # Exploit: pick a greedy action, breaking ties at random
        max_q = np.max(Q[state])
        return np.random.choice(np.argwhere(Q[state] == max_q).ravel())
# List to store the total reward per episode
rewards = []
for episode in range(num_episodes):
    state, _ = env.reset()          # gym >= 0.26: reset() returns (observation, info)
    total_reward = 0
    stepnum = 0                     # step counter (not used further)
    while True:
        stepnum += 1
        action = epsilon_greedy(state, epsilon)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        # Q-learning update
        best_next_action = np.argmax(Q[next_state])
        if done:
            td_target = reward      # terminal state: no bootstrapped future value
        else:
            td_target = reward + gamma * Q[next_state, best_next_action]
        Q[state, action] += alpha * (td_target - Q[state, action])
        state = next_state
        total_reward += reward
        if done:
            rewards.append(total_reward)
            break
env.close()
# Check: run one episode greedily with rendering enabled
env = gym.make('CliffWalking-v0', render_mode="human")
state, _ = env.reset()
while True:
    # epsilon = -1 disables exploration, so the greedy action is always chosen
    action = epsilon_greedy(state, -1)
    next_state, reward, terminated, truncated, _ = env.step(action)
    state = next_state
    if terminated or truncated:
        break
env.close()
# Plotting the total rewards per episode
plt.plot(rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Total Rewards per Episode in CliffWalking')
plt.show()
# Example of the Q-table after training
print("Q-table after training:")
print(Q)
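A convenient way to sanity-check what was learned is to print the greedy policy as arrows on the 4 × 12 grid. The snippet below is a small sketch that reuses the trained Q table from the script above; the arrow symbols and layout are illustrative choices, not part of the environment.

# Print the greedy action for every cell of the 4 x 12 grid as an arrow.
# Cells the agent never acts from (the cliff and the goal) keep all-zero Q values,
# so the arrow printed there is arbitrary.
arrows = {0: '^', 1: '>', 2: 'v', 3: '<'}
for row in range(4):
    line = ''
    for col in range(12):
        line += arrows[np.argmax(Q[row * 12 + col])]
    print(line)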
Results
The trained Q table (48 states × 4 actions) is printed below. The last 11 rows (states 37-47: the cliff cells and the goal) stay at zero, because the agent never takes an action from those states: falling into the cliff teleports it back to the start, and reaching the goal ends the episode.
Q-table after training:
[[ -11.01143656 -10.97200536 -11.07580864 -11.00034742]
[ -10.53532711 -10.53837496 -10.54711183 -10.56661425]
[ -9.93283073 -9.93636179 -9.96014774 -9.94421654]
[ -9.32524557 -9.25588848 -9.26132718 -9.362967 ]
[ -8.58502237 -8.51740036 -8.60589717 -8.69268765]
[ -7.75081398 -7.751814 -7.75891081 -7.92356173]
[ -7.0090195 -6.98093189 -7.04866952 -7.22981792]
[ -6.26849349 -6.1927947 -6.23173222 -6.36483544]
[ -5.41760614 -5.37364903 -5.40579277 -5.58507139]
[ -4.58345926 -4.56119197 -4.58452105 -4.65389169]
[ -3.81090246 -3.75268117 -3.75552881 -3.94623004]
[ -2.95189925 -2.94536982 -2.93811751 -3.07875708]
[ -11.39831906 -11.36886935 -11.44991149 -11.3913221 ]
[ -10.86782145 -10.87912745 -10.86594645 -11.00784236]
[ -10.16913152 -10.17961094 -10.18287655 -10.23801665]
[ -9.38427824 -9.38151279 -9.3878743 -9.55248997]
[ -8.59546788 -8.52601812 -8.53113836 -8.68376172]
[ -7.71283138 -7.65328652 -7.65204547 -7.76012005]
[ -6.86039027 -6.75206272 -6.75024712 -7.01878308]
[ -6.17358504 -5.83019078 -5.83165012 -6.89112464]
[ -5.01377416 -4.88875438 -4.88957237 -5.19514171]
[ -4.25528206 -3.93563165 -3.9355763 -4.26577178]
[ -3.33250507 -2.96882087 -2.9687877 -2.98457916]
[ -2.82388042 -2.31006529 -1.98999878 -2.29019384]
[ -12.00989998 -11.36151283 -12.90964752 -12.22182905]
[ -11.5315813 -10.46617457 -111.23155507 -12.13533698]
[ -10.88382501 -9.5617925 -109.98598868 -11.14939789]
[ -10.01862238 -8.64827525 -107.99851971 -10.30399957]
[ -9.15425612 -7.72553056 -109.62248996 -9.45146278]
[ -8.13263811 -6.79346521 -105.38023923 -8.4867883 ]
[ -7.38723868 -5.85198506 -110.32136827 -7.53211469]
[ -6.65884927 -4.90099501 -100.65150677 -6.55379866]
[ -5.40320615 -3.940399 -105.43007104 -5.68739168]
[ -4.71149533 -2.9701 -106.6354787 -4.57607392]
[ -3.69194928 -1.99 -105.15939645 -3.81256339]
[ -2.87617683 -1.95203942 -1. -2.80614277]
[ -12.2478977 -108.00745676 -13.09134424 -13.03572322]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]