依赖
pip install numpy
pip install pandas
pip install gym
运行图
代码
import os
import time
from itertools import count
import numpy as np
import pandas as pd
import gym
# import gymnasium as gym
# Number of training episodes run by QLearning.train().
NUM_EPISODES = 10
# Exploration rate (epsilon) the agent starts training with.
MAX_EPSILON = 1
# Lower bound the exploration rate decays towards.
MIN_EPSILON = 0.05
# Rate of the exponential epsilon decay, applied per episode index.
EPSILON_DECAY_RATE = 0.005
class QLearning(object):
    """Tabular Q-learning agent for an environment with discrete spaces.

    The Q-table is a ``pandas.DataFrame`` with one row per state and one
    column per action.  All reads and writes are positional, so the table
    behaves the same whether its column labels are ints (fresh table) or
    strings (table reloaded from CSV).
    """

    def __init__(self, env) -> None:
        """Create an all-zero Q-table sized from ``env``'s discrete spaces.

        Args:
            env: Environment exposing ``observation_space.n``,
                ``action_space.n``, ``action_space.sample()``, ``reset()``
                and a ``step()`` that returns a 5-tuple.
        """
        self.epsilon = MAX_EPSILON
        self.alpha = 0.5  # learning rate
        self.gamma = 0.95  # discount factor
        self.episodes = NUM_EPISODES
        self.env = env

        n_states = self.env.observation_space.n
        n_actions = self.env.action_space.n
        # The table shape is part of the file name, so tables trained on a
        # differently sized environment are never loaded by mistake.
        self.q_table_csv = './q_table_{}x{}.csv'.format(n_states, n_actions)
        self.q_table = pd.DataFrame(
            np.zeros((n_states, n_actions)),
            index=range(n_states),
            columns=range(n_actions),
        )
        print('qtable\n', self.q_table)

    def select_action(self, state, greedy=False):
        """Pick an action for ``state`` using an epsilon-greedy policy.

        Args:
            state: Row position of the current state in the Q-table.
            greedy: When true, never explore; always take the best-valued
                action.

        Returns:
            A sample from ``env.action_space`` when exploring, otherwise the
            positional index (``int``) of the highest Q-value in the row.
        """
        q_values = self.q_table.iloc[state].to_numpy()
        if not greedy:
            # A row that is still all zeros carries no information, so act
            # randomly there regardless of epsilon.
            unexplored = not q_values.any()
            if unexplored or np.random.uniform() < self.epsilon:
                return self.env.action_space.sample()
        # Positional argmax instead of idxmax(): the result is a valid action
        # index even when the column labels are strings read back from CSV.
        return int(q_values.argmax())

    def update_q_table(self, state, action, reward, next_state):
        """Apply one Q-learning update for a single transition.

        Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))

        Args:
            state: Row position of the state the action was taken in.
            action: Column position of the action taken.
            reward: Reward received for the transition.
            next_state: Row position of the resulting state.
        """
        current = self.q_table.iat[state, action]
        best_next = self.q_table.iloc[next_state].max()
        td_target = reward + self.gamma * best_next
        # Write through a single indexer.  Chained ``iloc[state][action] = v``
        # may assign into a temporary copy and leave the table unchanged.
        self.q_table.iat[state, action] = current + self.alpha * (td_target - current)

    def train(self):
        """Run ``self.episodes`` training episodes, then save the Q-table as CSV."""
        print('train')
        for i in range(self.episodes):
            observation, info = self.env.reset()
            for t in count():
                action = self.select_action(observation)
                observation_new, reward, terminated, truncated, info = self.env.step(action)
                done = terminated or truncated
                # An episode that ends without any reward is penalised.
                if done and reward == 0:
                    reward = -1
                self.update_q_table(observation, action, reward, observation_new)
                observation = observation_new
                if done:
                    # Exponentially decay exploration towards MIN_EPSILON.
                    self.epsilon = MIN_EPSILON + (MAX_EPSILON - MIN_EPSILON) * np.exp(-EPSILON_DECAY_RATE * i)
                    print(i, t, observation, action, observation_new, reward, terminated, truncated, self.epsilon, info)
                    # Use self, not a module-level variable, so the class
                    # also works when it is not run through the script entry.
                    print('latest q_table:\n', self.q_table)
                    break
        # save
        self.q_table.to_csv(self.q_table_csv, index=False)

    def play(self):
        """Play one episode greedily, loading the saved Q-table if one exists."""
        print('play', self.q_table_csv)
        if os.path.exists(self.q_table_csv):
            # CSV headers come back as strings; force every action column to float64.
            dtype = {str(col): 'float64' for col in range(self.env.action_space.n)}
            self.q_table = pd.read_csv(self.q_table_csv, header=0, dtype=dtype)
            print('read q_table\n', self.q_table)
        observation, info = self.env.reset()
        done = False
        while not done:
            action = self.select_action(observation, True)
            observation, _reward, terminated, truncated, info = self.env.step(int(action))
            done = terminated or truncated
            # Slow the loop down so each rendered step is visible.
            time.sleep(0.5)
# Script entry: train a Q-learning agent on the deterministic 4x4 FrozenLake
# map, replay the learned greedy policy once, and print the final Q-table.
if __name__ == '__main__':
    env = gym.make('FrozenLake-v1', desc=None, map_name='4x4', is_slippery=False, render_mode="human")
    try:
        qlearn = QLearning(env)
        qlearn.train()
        qlearn.play()
        print('latest q_table:\n', qlearn.q_table)
    finally:
        # Always release the environment's resources, even if training or
        # playback raises.
        env.close()
参考