Q-table training process (tabular Q-learning on FrozenLake)
import os
import time
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import clear_output
from IPython import display
from gymnasium import wrappers
# Use SDL's dummy video driver so pygame can render off-screen (headless / notebook use).
os.environ["SDL_VIDEODRIVER"] = 'dummy'
# IPython magic: display matplotlib figures inline in the notebook.
%matplotlib inline
# Plot the Q-value table as an annotated grid.
def draw_grid_with_numbers(mat, ax, n=4):
    """Draw an ``n x n`` grid and annotate each cell with its four Q-values.

    Each cell corresponds to one state; the four numbers placed left / below /
    right / above the cell centre are the Q-values of actions 0 / 1 / 2 / 3
    (FrozenLake: left, down, right, up).  The value of the action with the
    unique maximum is drawn in red, all others in green; if the maximum is
    tied, nothing is highlighted.

    Parameters
    ----------
    mat : ndarray of shape (n*n, 4)
        Q-table; row ``i*n + j`` holds the values of grid cell (i, j),
        where i counts rows from the top.
    ax : matplotlib Axes
        Axes to draw on; it is cleared first.
    n : int, optional
        Grid side length (default 4, the FrozenLake 4x4 map).

    Returns
    -------
    The same axes, for chaining.
    """
    ax.clear()
    for i in range(n):
        for j in range(n):
            # Row index i counts from the top, matplotlib's y axis from the
            # bottom — flip the row to place the cell correctly.
            row_y = n - 1 - i
            ax.add_patch(plt.Rectangle((j, row_y), 1, 1, fill=False, edgecolor='black'))
            center_x = j + 0.5
            center_y = row_y + 0.5
            q = mat[i * n + j]
            maxval = np.max(q)
            # Highlight the argmax action only when it is unique.
            maxidx = np.argmax(q) if np.sum(q == maxval) == 1 else -1
            # Above the centre: action 3 (up).
            ax.text(center_x, center_y + 0.3, f"{q[3]}", ha='center', va='center', c='r' if maxidx == 3 else 'g', fontsize=10)
            # Below the centre: action 1 (down).
            ax.text(center_x, center_y - 0.3, f"{q[1]}", ha='center', va='center', c='r' if maxidx == 1 else 'g', fontsize=10)
            # Left of the centre: action 0 (left).
            ax.text(center_x - 0.3, center_y, f"{q[0]}", ha='center', va='center', c='r' if maxidx == 0 else 'g', fontsize=10)
            # Right of the centre: action 2 (right).
            ax.text(center_x + 0.3, center_y, f"{q[2]}", ha='center', va='center', c='r' if maxidx == 2 else 'g', fontsize=10)
    # Axis range and ticks follow the grid size.
    ax.set_xlim(0, n)
    ax.set_ylim(0, n)
    ax.set_xticks(np.arange(0, n + 1, 1))
    ax.set_yticks(np.arange(0, n + 1, 1))
    ax.set_aspect('equal', adjustable='box')
    return ax
# Figure layout: left axes shows the rendered environment, right axes the Q-table.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 8))
# Create the environment; is_slippery=False makes transitions deterministic,
# so tabular Q-learning converges in very few episodes.
env = gym.make('FrozenLake-v1', render_mode='rgb_array', is_slippery=False)
env.reset()
# q_table: one row per state, one column per action; initialised to zero.
q_table = np.zeros([env.observation_space.n, env.action_space.n])
learning_rate = 0.8
discount_factor = 0.95
num_episodes = 20  # how many episodes to train for

for episode in range(num_episodes):
    state, info = env.reset()
    done = False
    while not done:
        # Epsilon-greedy action selection: explore with probability 0.1,
        # otherwise take the current best action for this state.
        if np.random.uniform(0, 1) < 0.1:
            action = env.action_space.sample()
        else:
            action = np.argmax(q_table[state, :])
        # Gymnasium's step() returns (obs, reward, terminated, truncated, info).
        next_state, reward, terminated, truncated, info = env.step(action)
        # FIX: the episode is over when it is terminated OR truncated (time
        # limit); the original bound `done` to terminated only and ignored
        # truncation, so a truncated episode would loop forever.
        done = terminated or truncated
        # Reward shaping (FrozenLake's native reward is 1 only at the goal):
        if terminated and reward != 1:
            # Fell into a hole: large penalty.
            reward = -10
        elif terminated and reward == 1:
            # Reached the goal: large bonus.
            reward = 10
        if next_state == state:
            # Bumped into the border (no movement): small penalty.
            reward = -1
        if reward == 0:
            # Ordinary step on the way: small step penalty to encourage short paths.
            reward = -1
        # Standard Q-learning update.
        q_table[state, action] = (
            q_table[state, action] +
            learning_rate * (reward + discount_factor * np.max(q_table[next_state, :]) - q_table[state, action])
        )
        state = next_state
        # Visualise: environment frame on the left, rounded Q-table on the right.
        title = f"episode: {episode}" + '\n' + f"state: {state}, action: {action}, reward: {reward}"
        ax1.imshow(env.render()); ax1.set_title(title)
        ax2 = draw_grid_with_numbers(np.round(q_table, 3), ax2)
        display.display(plt.gcf())
        display.clear_output(wait=True)
env.close()
Using the trained Q-table:
walking to the goal by following the Q-table greedily at every step.
# Greedy rollout: replay the learned policy from the start state.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 8))
env = gym.make('FrozenLake-v1', render_mode='rgb_array', is_slippery=False)
env.reset()
num_episodes = 1
actions = []  # record of the greedy action sequence, for inspection
for episode in range(num_episodes):
    state, info = env.reset()
    done = False
    # Cap at 10 steps; on the deterministic 4x4 map the optimal path is 6 moves.
    for step in range(10):
        # Always exploit: pick the highest-valued action for the current state.
        action = np.argmax(q_table[state, :]); actions.append(action)
        # Gymnasium's step() returns (obs, reward, terminated, truncated, info).
        next_state, reward, terminated, truncated, info = env.step(action)
        # FIX: stop on truncation as well as termination; the original ignored
        # the truncated flag.
        done = terminated or truncated
        state = next_state
        title = f'next_state: {next_state}, step: {step}' + '\n' + f"state: {state}, action: {action}"
        ax1.imshow(env.render()); ax1.set_title(title)
        display.display(plt.gcf())
        ax2 = draw_grid_with_numbers(np.round(q_table, 2), ax2)
        display.clear_output(wait=True)
        time.sleep(1)  # slow down so each step is visible
        if done:
            break
env.close()