解决强化学习的训练问题有很多种方法,本节用时间差分方法Sarsa来对一个简单的迷宫问题进行求解。
迷宫问题的地图简单描述如下。
同策略的Sarsa方法的动作值函数更新公式如下:
简单地说明一下:通过概率模拟,在状态s处选择执行动作a,到达了状态s',再利用状态s'处的Q(s',a')来更新Q(s,a)的值。但因为是模拟,所以不能直接用Q(s,a) = r + γQ(s',a')来计算;通过 r + γQ(s',a') - Q(s,a),会得到当前值函数Q(s,a)与最新模拟得到的值函数 r + γQ(s',a') 的偏差值,再将其按一定比例加到原来的Q(s,a)上,这个比例可以认为是传统意义上的学习率。
代码部分
import numpy as np
import random
from gym import spaces
import gym
from gym.envs.classic_control import rendering
# Simulated environment: a 5x5 grid-world maze.
# States are numbered 1..25 in row-major order; 8 cells are terminal
# (7 traps plus the "gold" exit at state 15).
class GridWorldEnv(gym.Env):
    # Rendering configuration consumed by gym's render/video machinery.
    metadata = {
        'render.modes':['human', 'rgb_array'],
        'video.frames_per_second': 2
    }

    def __init__(self):
        self.states = [i for i in range(1, 26)] # state space: cells 1..25
        self.terminate_states = [3, 4, 5, 11, 12, 19, 24, 15] # terminal states: traps plus the gold cell (15)
        self.actions = ['up', 'down', 'left', 'right'] # action space
        self.value_of_state = dict() # state-value table V(s)
        for state in self.states:
            self.value_of_state[state] = 0.0
        for state in self.terminate_states: # every trap starts with value -1.0
            self.value_of_state[state] = -1.0
        self.value_of_state[15] = 1.0 # the gold (goal) cell starts with value 1
        self.initStateAction() # feasible action set for every state
        self.initStatePolicyAction() # random initial policy
        self.initQ_s_a()
        self.gamma = 0.8 # discount factor used in the Q backup
        self.alpha = 0.1 # learning rate
        self.viewer = None # rendering viewer (created lazily in render)
        self.current_state = None # current state
        return

    def translateStateToRowCol(self, state):
        """
        Convert a state id (1..25) to 0-based (row, col) coordinates.
        """
        row = (state - 1) // 5
        col = (state - 1) % 5
        return row, col

    def translateRowColToState(self, row, col):
        """
        Convert 0-based (row, col) coordinates back to a state id (1..25).
        """
        return row * 5 + col + 1

    def actionRowCol(self, row, col, action):
        """
        Apply `action` to (row, col) and return the new coordinates.
        No bounds clamping is done here: callers only pass actions that
        initStateAction already validated as feasible.
        """
        if action == "up":
            row = row - 1
        if action == "down":
            row = row + 1
        if action == "left":
            col = col - 1
        if action == "right":
            col = col + 1
        return row, col

    # Feasibility checks for each direction. The second parameter of each
    # is unused; it is kept so all four share a uniform (row, col) signature.
    def canUp(self, row, col):
        row = row - 1
        return 0 <= row <= 4

    def canDown(self, row, col):
        row = row + 1
        return 0 <= row <= 4

    def canLeft(self, row, col):
        col = col - 1
        return 0 <= col <= 4

    def canRight(self, row, col):
        col = col + 1
        return 0 <= col <= 4

    def initStateAction(self):
        """
        Build self.states_actions: for each state, the list of actions that
        stay inside the 5x5 grid. Terminal states get an empty list.
        """
        self.states_actions = dict()
        for state in self.states:
            self.states_actions[state] = []
            if state in self.terminate_states:
                continue
            row, col = self.translateStateToRowCol(state)
            if self.canUp(row, col):
                self.states_actions[state].append("up")
            if self.canDown(row, col):
                self.states_actions[state].append("down")
            if self.canLeft(row, col):
                self.states_actions[state].append('left')
            if self.canRight(row, col):
                self.states_actions[state].append('right')
        return

    def initQ_s_a(self):
        """
        Initialize the action-value table Q(s, a) to 0.0 for every feasible
        (state, action) pair. Keys are strings of the form "state_action".
        """
        self.Q_s_a = dict()
        for state in self.states:
            if state in self.terminate_states:
                continue
            for action in self.states_actions[state]:
                self.Q_s_a["%d_%s" % (state, action)] = 0.0 # initialize all action values

    def epsilon_greedy(self, state, epsilon):
        """
        ε-greedy action selection at `state`: the greedy action receives
        probability 1 - ε + ε/|A|, every other action ε/|A|.
        """
        action_size = len(self.states_actions[state])
        # find the greedy (argmax-Q) action
        max_value_action = self.states_actions[state][0]
        for action in self.states_actions[state]:
            if self.Q_s_a["%d_%s" % (state, action)] > self.Q_s_a["%d_%s" % (state, max_value_action)]:
                max_value_action = action
        prob_list = [0.0 for _ in range(0, action_size)]
        for i in range(0, action_size):
            if self.states_actions[state][i] == max_value_action:
                prob_list[i] = 1 - epsilon + epsilon / action_size
            else:
                prob_list[i] = epsilon / action_size
        # sample from prob_list by walking its cumulative distribution
        r = random.random()
        s = 0.0
        for i in range(0, action_size):
            s += prob_list[i]
            if s >= r:
                return self.states_actions[state][i]
        # float-rounding fallback: if the cumulative sum never reached r,
        # return the last action
        return self.states_actions[state][-1]

    def greedy_action(self, state):
        """
        Return the greedy action at `state`, i.e. argmax over Q(state, a).
        Ties are broken in favor of the earlier action in the list.
        """
        action_size = len(self.states_actions[state])
        max_value_action = self.states_actions[state][0]
        for action in self.states_actions[state]:
            if self.Q_s_a["%d_%s" % (state, action)] > self.Q_s_a["%d_%s" % (state, max_value_action)]:
                max_value_action = action
        return max_value_action

    def initStatePolicyAction(self):
        """
        Initialize the current policy: a uniformly random feasible action for
        every non-terminal state, None for terminal states.
        """
        self.states_policy_action = dict()
        for state in self.states:
            if state in self.terminate_states:
                self.states_policy_action[state] = None
            else:
                self.states_policy_action[state] = random.sample(self.states_actions[state], 1)[0]
        return

    def seed(self, seed = None):
        # Seeds Python's global `random` module — the only RNG this class uses.
        random.seed(seed)
        return [seed]

    def reset(self):
        """
        Reset to a uniformly random state (which may be a terminal one).
        NOTE: unlike the usual gym contract, no observation is returned.
        """
        self.current_state = random.sample(self.states, 1)[0]

    def step(self, action):
        """
        Apply `action` to the current state.
        Returns (next_state, reward, done, info). The reward is always 0
        here — the Sarsa update reads the terminal payoff from
        self.value_of_state instead (see policy_evaluate_sarsa).
        """
        cur_state = self.current_state
        if cur_state in self.terminate_states:
            return cur_state, 0, True, {}
        row, col = self.translateStateToRowCol(cur_state)
        n_row, n_col = self.actionRowCol(row, col, action)
        next_state = self.translateRowColToState(n_row, n_col)
        self.current_state = next_state
        if next_state in self.terminate_states:
            return next_state, 0, True, {}
        else:
            return next_state, 0, False, {}

    def policy_evaluate_sarsa(self):
        """
        One Sarsa evaluation sweep: performs a single Q backup for every
        non-terminal (state, action) pair — an exhaustive sweep over the
        state-action space rather than a sampled trajectory.
        """
        for state in self.states:
            if state in self.terminate_states:
                continue
            for action in self.states_actions[state]:
                self.current_state = state
                next_state, reward, isTerminate, info = self.step(action)
                if isTerminate is True:
                    # Terminal successor: bootstrap from the fixed payoff in
                    # value_of_state (-1 trap / +1 gold), since terminal
                    # states have no Q entries.
                    s_a = "%d_%s" % (state, action)
                    self.Q_s_a[s_a] = self.Q_s_a[s_a] + self.alpha * (reward + self.gamma * self.value_of_state[next_state] - self.Q_s_a[s_a])
                else:
                    # Non-terminal successor: draw a' ε-greedily (ε = 0.3)
                    # and bootstrap from Q(s', a') — the on-policy Sarsa target.
                    s_a = "%d_%s" % (state, action)
                    n_action = self.epsilon_greedy(next_state, 0.3)
                    n_s_a = "%d_%s" % (next_state, n_action)
                    self.Q_s_a[s_a] = self.Q_s_a[s_a] + self.alpha * (reward + self.gamma * self.Q_s_a[n_s_a] - self.Q_s_a[s_a])
        return

    def policy_improve_sarsa(self):
        """
        Policy improvement: make the stored policy greedy w.r.t. the current Q.
        """
        for state in self.states:
            if state in self.terminate_states:
                continue
            self.states_policy_action[state] = self.greedy_action(state)
        return

    def createGrids(self):
        """
        Draw the grid: the four edges of each cell as one-time lines.
        40 px is both the margin (start_x/start_y) and the cell size.
        """
        start_x = 40
        start_y = 40
        line_length = 40
        for state in self.states:
            row, col = self.translateStateToRowCol(state)
            x = start_x + col * line_length
            y = start_y + row * line_length
            line = rendering.Line((x, y), (x + line_length, y))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x, y), (x, y + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x + line_length, y), (x + line_length, y + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x, y + line_length), (x + line_length, y + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)

    def createTraps(self):
        """
        Draw the traps as black circles. The gold cell (15) is also in
        terminate_states, so it is drawn as a trap first and then painted
        over by createGold.
        """
        start_x = 40
        start_y = 40
        line_length = 40
        for state in self.terminate_states:
            row, col = self.translateStateToRowCol(state)
            trap = rendering.make_circle(20)
            trans = rendering.Transform()
            trap.add_attr(trans)
            trap.set_color(0, 0, 0)
            trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
            self.viewer.add_onetime(trap)

    def createGold(self):
        """
        Draw the gold (the maze exit) as a yellow circle at state 15.
        """
        start_x = 40
        start_y = 40
        line_length = 40
        state = 15
        row, col = self.translateStateToRowCol(state)
        gold = rendering.make_circle(20)
        trans = rendering.Transform()
        gold.add_attr(trans)
        gold.set_color(1, 0.9, 0)
        trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
        self.viewer.add_onetime(gold)

    def createRobot(self):
        """
        Draw the robot (the agent) as a magenta circle at the current state.
        """
        start_x = 40
        start_y = 40
        line_length = 40
        row, col = self.translateStateToRowCol(self.current_state)
        robot = rendering.make_circle(15)
        trans = rendering.Transform()
        robot.add_attr(trans)
        robot.set_color(1, 0, 1)
        trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
        self.viewer.add_onetime(robot)

    def render(self, mode="human", close=False):
        """
        Render the whole scene.
        NOTE: with close=True the viewer is destroyed but the method does not
        return early — it falls through and recreates the viewer below.
        """
        # close the viewer
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
        # viewer size: 5 cells * 40 px + 2 * 40 px margin = 280
        screen_width = 280
        screen_height = 280
        if self.viewer is None:
            self.viewer = rendering.Viewer(screen_width, screen_height)
        # draw the grid
        self.createGrids()
        # draw the traps
        self.createTraps()
        # draw the gold
        self.createGold()
        # draw the robot
        self.createRobot()
        return self.viewer.render(return_rgb_array= mode == 'rgb_array')
注册类到gym
from gym.envs.registration import register

# Register the environment so gym.make('GridWorld-v5') works.
# Re-running this module tries to register the same id again, which raises;
# catch Exception only (a bare `except:` would also swallow SystemExit and
# KeyboardInterrupt).
try:
    register(id="GridWorld-v5", entry_point=GridWorldEnv, max_episode_steps=200, reward_threshold=100.0)
except Exception:
    pass
动画模拟
from time import sleep

env = gym.make('GridWorld-v5')
env.reset()

# Training: alternate a full Sarsa evaluation sweep with a greedy
# policy-improvement step.
for _ in range(100000):
    env.env.policy_evaluate_sarsa()
    env.env.policy_improve_sarsa()

# Debug output: inspect the wrapper gym.make put around our raw environment.
print(isinstance(env, GridWorldEnv))
print(type(env))
print(env.__dict__)
print(isinstance(env.env, GridWorldEnv))

# Animation: follow the learned policy, restarting whenever an episode ends.
env.reset()
for step_no in range(1000):
    env.render()
    chosen = env.env.states_policy_action[env.env.current_state]
    if chosen is not None:
        observation, reward, done, info = env.step(chosen)
    else:
        # terminal state has no policy action — treat as episode end
        done = True
    print(step_no)
    if done:
        sleep(0.5)
        env.render()
        env.reset()
        print("reset")
    sleep(0.5)
env.close()