This post uses Q-learning to train a robot to find the gold coin (i.e., to solve a maze). All packages used are recent versions. After a series of extensions, you now only need to specify the maze size, the cell numbers of the traps, and the cell number of the goal, and then tune the number of training iterations, to train a maze of any size.
I. Creating a gym environment for the maze grid world
1. A gym environment file is built around a single class; here we name it GridEnv1, whose __init__ sets up the basic parameters of the environment. The complete grid-world code lives in grid_mdp1.py.
import logging
import random
import gym
from gym.envs.classic_control import rendering
from gym.envs.classic_control import Builder
logger = logging.getLogger(__name__)
class GridEnv1(gym.Env):
    metadata = {
        "render.modes": ["human", "rgb_array"],
        "video.frames_per_second": 2
    }

    def __init__(self):
        # Maze size and special cells: trap cells and the goal cell, numbered 1..n
        self.grid_size = 10
        self.trap_cells = {1, 36, 39, 59, 10, 60, 47, 48, 82, 56, 89, 91, 28, 94}
        self.end_cell = 100
        self.states = []  # state space
        self.n = self.grid_size * self.grid_size  # total number of states
        for i in range(self.n):
            self.states.append(i + 1)
        self.terminate_states = dict()  # terminal states, stored as a dict
        for i in self.trap_cells:
            self.terminate_states[i] = 1
        self.terminate_states[self.end_cell] = 1
        self.actions = ['n', 'e', 's', 'w']  # action space
        self.action_space = gym.spaces.Discrete(4)  # gym action space
        self.observation_space = gym.spaces.Discrete(self.n)  # gym observation space
        # Create the Builder instance that generates transitions, rewards and best_qfunc
        self.builder = Builder.Builder(self.grid_size, self.trap_cells, self.end_cell)
        # Reward function, stored as a dict
        self.rewards = self.builder.rewards_builder()
        # State transitions, stored as a dict
        self.t = self.builder.t_builder()
        # Generate best_qfunc1 and write it to a file (used later for the error curve)
        self.sequences = self.builder.best_qfunc_builder()
        self.sequences_outputs = [f"{key}:{value:.6f}" for key, value in self.sequences.items()]
        self.file_name = 'best_qfunc1'
        with open(self.file_name, 'w') as f:
            for item in self.sequences_outputs:
                f.write(f"{item}\n")
        self.gamma = 0.8  # discount factor
        self.viewer = None
        self.state = None

    def getTerminal(self):
        return self.terminate_states

    def getGamma(self):
        return self.gamma

    def getStates(self):
        return self.states

    def getAction(self):
        return self.actions

    def getTerminate_states(self):
        return self.terminate_states

    def setAction(self, s):
        self.state = s
    def step(self, action):
        # Current state of the system
        state = self.state
        # If the current state is already terminal, stay there with zero reward
        if state in self.terminate_states:
            return state, 0, True, False, {}
        key = "%d_%s" % (state, action)  # dict key built from the state and the action
        # State transition
        if key in self.t:
            next_state = self.t[key]
        else:
            next_state = state
        self.state = next_state
        is_terminal = False
        if next_state in self.terminate_states:
            is_terminal = True
        if key not in self.rewards:
            r = 0.0
        else:
            r = self.rewards[key]
        return next_state, r, is_terminal, False, {}

    def reset(self):
        # Start each episode from a uniformly random state
        self.state = self.states[int(random.random() * len(self.states))]
        return self.state
    def render(self, mode='human', close=False):
        screen_width = 100 * (self.grid_size + 2)
        screen_height = 100 * (self.grid_size + 2)
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
            return
        if self.viewer is None:
            self.viewer = rendering.Viewer(screen_width, screen_height)
            # Draw the grid lines
            for i in range(100, 100 * (self.grid_size + 1) + 1, 100):
                line = rendering.Line((i, 100 * (self.grid_size + 1)), (i, 100))
                line.set_color(0, 0, 0)
                self.viewer.add_geom(line)
            for j in range(100, 100 * (self.grid_size + 1) + 1, 100):
                line = rendering.Line((100, j), (100 * (self.grid_size + 1), j))
                line.set_color(0, 0, 0)
                self.viewer.add_geom(line)
            # Shapes for the obstacles, the gold and the robot
            # Colors
            obstacle_color = (0, 0, 0)
            gold_color = (1, 0.9, 0)
            robot_color = (0.8, 0.6, 0.4)
            # Center coordinates of the robot at every state
            self.x = []
            self.y = []
            self.obstacles = []
            self.gold = []
            flag = 1
            for k in self.trap_cells:
                count = 1
                for i in range(self.grid_size):
                    for j in range(self.grid_size):
                        if flag:
                            # Fill the coordinate tables only on the first pass
                            self.x.append(150 + 100 * j)
                            self.y.append(self.grid_size * 100 + 50 - 100 * i)
                        if k == count:
                            self.obstacles.append((150 + 100 * j, self.grid_size * 100 + 50 - 100 * i))
                        if self.end_cell == count and flag:
                            self.gold.append((150 + 100 * j, self.grid_size * 100 + 50 - 100 * i))
                        count += 1
                flag = 0
            # Draw the traps (skulls)
            for (x, y) in self.obstacles:
                self._add_circle(x, y, 50, obstacle_color)
            # Draw the gold
            for (x, y) in self.gold:
                self._add_circle(x, y, 50, gold_color)
            # Draw the robot
            self.robot = rendering.make_circle(30)
            self.robotrans = rendering.Transform()
            self.robot.add_attr(self.robotrans)
            self.robot.set_color(*robot_color)
            self.viewer.add_geom(self.robot)
        if self.state is None:
            return None
        self.robotrans.set_translation(self.x[self.state - 1], self.y[self.state - 1])
        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    def _add_circle(self, x, y, radius, color):
        circle = rendering.make_circle(radius)
        circletrans = rendering.Transform(translation=(x, y))
        circle.add_attr(circletrans)
        circle.set_color(*color)
        self.viewer.add_geom(circle)
        return circle

    def close(self):
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None
        return
2. Now the key part: how to register the environment we just built, so that it can be created through gym's standard interface.
Step 1: copy your environment file (mine is named grid_mdp1.py) into <your gym install directory>/gym/gym/envs/classic_control. (It goes into this folder because the environment uses the rendering module. Of course, there are other ways to arrange this; this method is not the only one.)
To make rendering available, save the code below as rendering.py in that same folder:
"""
2D rendering framework
"""
from __future__ import division
import os
import six
import sys
if "Apple" in sys.version:
if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ:
os.environ['DYLD_FALLBACK_LIBRARY_PATH'] += ':/usr/lib'
# (JDS 2016/04/15): avoid bug on Anaconda 2.3.0 / Yosemite
from gym import error
try:
import pyglet
except ImportError as e:
raise ImportError('''
Cannot import pyglet.
HINT: you can install pyglet directly via 'pip install pyglet'.
But if you really just want to install all Gym dependencies and not have to think about it,
'pip install -e .[all]' or 'pip install gym[all]' will do it.
''')
try:
from pyglet.gl import *
except ImportError as e:
raise ImportError('''
Error occurred while running `from pyglet.gl import *`
HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get install python-opengl'.
If you're running on a server, you may need a virtual frame buffer; something like this should work:
'xvfb-run -s \"-screen 0 1400x900x24\" python <your_script.py>'
''')
import math
import numpy as np
RAD2DEG = 57.29577951308232
def get_display(spec):
"""Convert a display specification (such as :0) into an actual Display
object.
Pyglet only supports multiple Displays on Linux.
"""
if spec is None:
return None
elif isinstance(spec, six.string_types):
return pyglet.canvas.Display(spec)
else:
raise error.Error('Invalid display specification: {}. (Must be a string like :0 or None.)'.format(spec))
class Viewer(object):
def __init__(self, width, height, display=None):
display = get_display(display)
self.width = width
self.height = height
self.window = pyglet.window.Window(width=width, height=height, display=display)
self.window.on_close = self.window_closed_by_user
self.isopen = True
self.geoms = []
self.onetime_geoms = []
self.transform = Transform()
glEnable(GL_BLEND)
glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
def close(self):
self.window.close()
def window_closed_by_user(self):
self.isopen = False
def set_bounds(self, left, right, bottom, top):
assert right > left and top > bottom
scalex = self.width/(right-left)
scaley = self.height/(top-bottom)
self.transform = Transform(
translation=(-left*scalex, -bottom*scaley),
scale=(scalex, scaley))
def add_geom(self, geom):
self.geoms.append(geom)
def add_onetime(self, geom):
self.onetime_geoms.append(geom)
def render(self, return_rgb_array=False):
glClearColor(1,1,1,1)
self.window.clear()
self.window.switch_to()
self.window.dispatch_events()
self.transform.enable()
for geom in self.geoms:
geom.render()
for geom in self.onetime_geoms:
geom.render()
self.transform.disable()
arr = None
if return_rgb_array:
buffer = pyglet.image.get_buffer_manager().get_color_buffer()
image_data = buffer.get_image_data()
arr = np.frombuffer(image_data.data, dtype=np.uint8)
# In https://github.com/openai/gym-http-api/issues/2, we
# discovered that someone using Xmonad on Arch was having
# a window of size 598 x 398, though a 600 x 400 window
# was requested. (Guess Xmonad was preserving a pixel for
# the boundary.) So we use the buffer height/width rather
# than the requested one.
arr = arr.reshape(buffer.height, buffer.width, 4)
arr = arr[::-1,:,0:3]
self.window.flip()
self.onetime_geoms = []
return arr if return_rgb_array else self.isopen
# Convenience
def draw_circle(self, radius=10, res=30, filled=True, **attrs):
geom = make_circle(radius=radius, res=res, filled=filled)
_add_attrs(geom, attrs)
self.add_onetime(geom)
return geom
def draw_polygon(self, v, filled=True, **attrs):
geom = make_polygon(v=v, filled=filled)
_add_attrs(geom, attrs)
self.add_onetime(geom)
return geom
def draw_polyline(self, v, **attrs):
geom = make_polyline(v=v)
_add_attrs(geom, attrs)
self.add_onetime(geom)
return geom
def draw_line(self, start, end, **attrs):
geom = Line(start, end)
_add_attrs(geom, attrs)
self.add_onetime(geom)
return geom
def get_array(self):
self.window.flip()
image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
self.window.flip()
arr = np.fromstring(image_data.data, dtype=np.uint8, sep='')
arr = arr.reshape(self.height, self.width, 4)
return arr[::-1,:,0:3]
def __del__(self):
self.close()
def _add_attrs(geom, attrs):
if "color" in attrs:
geom.set_color(*attrs["color"])
if "linewidth" in attrs:
geom.set_linewidth(attrs["linewidth"])
class Geom(object):
def __init__(self):
self._color=Color((0, 0, 0, 1.0))
self.attrs = [self._color]
def render(self):
for attr in reversed(self.attrs):
attr.enable()
self.render1()
for attr in self.attrs:
attr.disable()
def render1(self):
raise NotImplementedError
def add_attr(self, attr):
self.attrs.append(attr)
def set_color(self, r, g, b):
self._color.vec4 = (r, g, b, 1)
class Attr(object):
def enable(self):
raise NotImplementedError
def disable(self):
pass
class Transform(Attr):
def __init__(self, translation=(0.0, 0.0), rotation=0.0, scale=(1,1)):
self.set_translation(*translation)
self.set_rotation(rotation)
self.set_scale(*scale)
def enable(self):
glPushMatrix()
glTranslatef(self.translation[0], self.translation[1], 0) # translate to GL loc ppint
glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0)
glScalef(self.scale[0], self.scale[1], 1)
def disable(self):
glPopMatrix()
def set_translation(self, newx, newy):
self.translation = (float(newx), float(newy))
def set_rotation(self, new):
self.rotation = float(new)
def set_scale(self, newx, newy):
self.scale = (float(newx), float(newy))
class Color(Attr):
def __init__(self, vec4):
self.vec4 = vec4
def enable(self):
glColor4f(*self.vec4)
class LineStyle(Attr):
def __init__(self, style):
self.style = style
def enable(self):
glEnable(GL_LINE_STIPPLE)
glLineStipple(1, self.style)
def disable(self):
glDisable(GL_LINE_STIPPLE)
class LineWidth(Attr):
def __init__(self, stroke):
self.stroke = stroke
def enable(self):
glLineWidth(self.stroke)
class Point(Geom):
def __init__(self):
Geom.__init__(self)
def render1(self):
glBegin(GL_POINTS) # draw point
glVertex3f(0.0, 0.0, 0.0)
glEnd()
class FilledPolygon(Geom):
def __init__(self, v):
Geom.__init__(self)
self.v = v
def render1(self):
if len(self.v) == 4 : glBegin(GL_QUADS)
elif len(self.v) > 4 : glBegin(GL_POLYGON)
else: glBegin(GL_TRIANGLES)
for p in self.v:
glVertex3f(p[0], p[1],0) # draw each vertex
glEnd()
def make_circle(radius=10, res=30, filled=True):
points = []
for i in range(res):
ang = 2*math.pi*i / res
points.append((math.cos(ang)*radius, math.sin(ang)*radius))
if filled:
return FilledPolygon(points)
else:
return PolyLine(points, True)
def make_polygon(v, filled=True):
if filled: return FilledPolygon(v)
else: return PolyLine(v, True)
def make_polyline(v):
return PolyLine(v, False)
def make_capsule(length, width):
l, r, t, b = 0, length, width/2, -width/2
box = make_polygon([(l,b), (l,t), (r,t), (r,b)])
circ0 = make_circle(width/2)
circ1 = make_circle(width/2)
circ1.add_attr(Transform(translation=(length, 0)))
geom = Compound([box, circ0, circ1])
return geom
class Compound(Geom):
def __init__(self, gs):
Geom.__init__(self)
self.gs = gs
for g in self.gs:
g.attrs = [a for a in g.attrs if not isinstance(a, Color)]
def render1(self):
for g in self.gs:
g.render()
class PolyLine(Geom):
def __init__(self, v, close):
Geom.__init__(self)
self.v = v
self.close = close
self.linewidth = LineWidth(1)
self.add_attr(self.linewidth)
def render1(self):
glBegin(GL_LINE_LOOP if self.close else GL_LINE_STRIP)
for p in self.v:
glVertex3f(p[0], p[1],0) # draw each vertex
glEnd()
def set_linewidth(self, x):
self.linewidth.stroke = x
class Line(Geom):
def __init__(self, start=(0.0, 0.0), end=(0.0, 0.0)):
Geom.__init__(self)
self.start = start
self.end = end
self.linewidth = LineWidth(1)
self.add_attr(self.linewidth)
def render1(self):
glBegin(GL_LINES)
glVertex2f(*self.start)
glVertex2f(*self.end)
glEnd()
class Image(Geom):
def __init__(self, fname, width, height):
Geom.__init__(self)
self.width = width
self.height = height
img = pyglet.image.load(fname)
self.img = img
self.flip = False
def render1(self):
self.img.blit(-self.width/2, -self.height/2, width=self.width, height=self.height)
# ================================================================
class SimpleImageViewer(object):
def __init__(self, display=None, maxwidth=500):
self.window = None
self.isopen = False
self.display = display
self.maxwidth = maxwidth
def imshow(self, arr):
if self.window is None:
height, width, _channels = arr.shape
if width > self.maxwidth:
scale = self.maxwidth / width
width = int(scale * width)
height = int(scale * height)
self.window = pyglet.window.Window(width=width, height=height,
display=self.display, vsync=False, resizable=True)
self.width = width
self.height = height
self.isopen = True
@self.window.event
def on_resize(width, height):
self.width = width
self.height = height
@self.window.event
def on_close():
self.isopen = False
assert len(arr.shape) == 3, "You passed in an image with the wrong number shape"
image = pyglet.image.ImageData(arr.shape[1], arr.shape[0],
'RGB', arr.tobytes(), pitch=arr.shape[1]*-3)
gl.glTexParameteri(gl.GL_TEXTURE_2D,
gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST)
texture = image.get_texture()
texture.width = self.width
texture.height = self.height
self.window.clear()
self.window.switch_to()
self.window.dispatch_events()
texture.blit(0, 0) # draw
self.window.flip()
def close(self):
if self.isopen and sys.meta_path:
# ^^^ check sys.meta_path to avoid 'ImportError: sys.meta_path is None, Python is likely shutting down'
self.window.close()
self.isopen = False
def __del__(self):
self.close()
Note that for rendering to work you also need to install the following two dependencies:
pip install six
pip install pyglet==1.5.27
Then copy my Builder.py into the same folder:
class Builder:
    def __init__(self, grid_size, trap_cells, end_cell):
        self.grid_size = grid_size
        self.trap_cells = trap_cells
        self.end_cell = end_cell

    def t_builder(self):
        # Deterministic transitions: from every non-terminal cell, n/e/s/w
        # moves to the neighboring cell; moves off the grid are simply omitted
        sequences = {}
        for cell in range(1, self.grid_size ** 2 + 1):
            if cell not in self.trap_cells and cell != self.end_cell:
                if cell > self.grid_size:
                    sequences[f"{cell}_n"] = cell - self.grid_size
                if cell % self.grid_size != 0:
                    sequences[f"{cell}_e"] = cell + 1
                if cell + self.grid_size <= self.grid_size ** 2:
                    sequences[f"{cell}_s"] = cell + self.grid_size
                if (cell - 1) % self.grid_size != 0:
                    sequences[f"{cell}_w"] = cell - 1
        return sequences

    def rewards_builder(self):
        # Rewards: -1.0 for stepping into a trap, +1.0 for reaching the goal;
        # all other transitions get no entry (treated as reward 0)
        sequences = {}
        for cell in range(1, self.grid_size * self.grid_size + 1):
            if cell not in self.trap_cells and cell != self.end_cell:
                for direction in ['n', 'e', 's', 'w']:
                    if direction == 'n' and cell > self.grid_size:
                        target_cell = cell - self.grid_size
                    elif direction == 'e' and cell % self.grid_size != 0:
                        target_cell = cell + 1
                    elif direction == 's' and cell <= self.grid_size * (self.grid_size - 1):
                        target_cell = cell + self.grid_size
                    elif direction == 'w' and cell % self.grid_size != 1:
                        target_cell = cell - 1
                    else:
                        continue
                    if target_cell in self.trap_cells or target_cell == self.end_cell:
                        value = -1.0 if target_cell in self.trap_cells else 1.0
                        sequences[f"{cell}_{direction}"] = value
        return sequences

    def best_qfunc_builder(self):
        # A rough hand-built reference action-value table ("best_qfunc"):
        # 0 for terminal cells, -1/+1 for moves into traps/the goal,
        # 0.512 (= 0.8^3) for bumping into a wall and 0.64 (= 0.8^2) otherwise
        sequences = {}
        for cell in range(1, self.grid_size * self.grid_size + 1):
            if cell in self.trap_cells or cell == self.end_cell:
                sequences[f"{cell}_n"] = 0.000000
                sequences[f"{cell}_e"] = 0.000000
                sequences[f"{cell}_s"] = 0.000000
                sequences[f"{cell}_w"] = 0.000000
            else:
                for direction in ['n', 'e', 's', 'w']:
                    if direction == 'n' and cell > self.grid_size:
                        target_cell = cell - self.grid_size
                    elif direction == 'e' and cell % self.grid_size != 0:
                        target_cell = cell + 1
                    elif direction == 's' and cell <= self.grid_size * (self.grid_size - 1):
                        target_cell = cell + self.grid_size
                    elif direction == 'w' and cell % self.grid_size != 1:
                        target_cell = cell - 1
                    else:
                        target_cell = cell
                    if target_cell == cell:
                        value = 0.512000
                    elif target_cell in self.trap_cells:
                        value = -1.000000
                    elif target_cell == self.end_cell:
                        value = 1.000000
                    else:
                        value = 0.640000
                    sequences[f"{cell}_{direction}"] = value
        return sequences
""" 测试
# 定义网格大小和特殊格子
grid_size = 5
trap_cells = {4, 9, 11, 12, 23, 24, 25}
end_cell = 15
# 创建Builder实例
builder = Builder(grid_size, trap_cells, end_cell)
# 调用不同方法生成序列并打印结果
print("状态转移t:")
sequences = builder.t_builder()
sequences_outputs = [f"self.t['{key}'] = {value}" for key, value in sequences.items()]
for i in range(len(sequences_outputs)):
print(sequences_outputs[i])
print("奖励rewards:")
sequences = builder.rewards_builder()
sequences_outputs = [f"self.rewards['{key}'] = {value}" for key, value in sequences.items()]
for i in range(len(sequences_outputs)):
print(sequences_outputs[i])
print("\nbest_qfunc:")
sequences = builder.best_qfunc_builder()
sequences_outputs = [f"{key}:{value:.6f}" for key, value in sequences.items()]
file_name = 'best_qfunc2'
with open(file_name, 'w') as f:
for item in sequences_outputs:
f.write(f"{item}\n")
"""
Step 2: open the __init__.py file in that folder (the folder from step 1) and append this line at the end:
from gym.envs.classic_control.grid_mdp1 import GridEnv1
Step 3: go up one level to <your gym install directory>/gym/gym/envs, open that folder's __init__.py, and add:
register(
    id='GridWorld-v1',
    entry_point='gym.envs.classic_control.grid_mdp1:GridEnv1',
    max_episode_steps=200,
    reward_threshold=100.0,
)
The first argument, id, is the name you pass to gym.make('id'); you can pick it freely, and I chose GridWorld-v1. The second argument is the entry point of the environment class. Strictly speaking, the remaining arguments are optional.
These three steps complete the registration.
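Incidentally, as noted in step 1, editing gym's installed source is not the only route. A minimal sketch of an alternative (my addition, not part of the original tutorial): if grid_mdp1.py sits in your own project directory (with its rendering import pointed at a local copy of rendering.py), you can register the environment at runtime through the same register() function that gym/envs/__init__.py calls:

import gym
from gym.envs.registration import register

# Register at runtime instead of editing gym/envs/__init__.py;
# 'grid_mdp1:GridEnv1' is resolved from sys.path, so run this from the project folder
register(
    id='GridWorld-v1',
    entry_point='grid_mdp1:GridEnv1',
    max_episode_steps=200,
)

env = gym.make('GridWorld-v1')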
3. Testing
Open the Anaconda Prompt, activate your virtual environment, and enter the following lines one at a time:
python
import gym
env = gym.make('GridWorld-v1')
env.reset()
env.render()
If everything works, a picture of the environment will pop up (admittedly it is a bit large; the window size needs some tuning).
Finally, enter env.close() to close the environment.
II. The Q-learning algorithm
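Before the code, recall the tabular Q-learning update that the training loop below implements: after each transition $(s, a, r, s')$,

$$Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \right]$$

where $\alpha$ is the learning rate and $\gamma$ is the discount factor. The behavior policy is $\epsilon$-greedy, while the $\max$ in the target uses the greedy action, which is what makes Q-learning off-policy.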
All of the Q-learning code lives in qlearning.py, as follows:
import gym
import random
random.seed(0)  # fix the random seed so every run produces the same sequence
import matplotlib.pyplot as plt

# Create the grid world
grid = gym.make('GridWorld-v1')
states = grid.env.getStates()  # state space of the grid world
actions = grid.env.getAction()  # action space of the grid world
gamma = grid.env.getGamma()  # discount factor

# Gap between the current policy and the optimal policy
best = dict()  # stores the reference optimal action-value function

def read_best():
    f = open("best_qfunc1")
    for line in f:
        line = line.strip()
        if len(line) == 0:
            continue
        eles = line.split(":")
        best[eles[0]] = float(eles[1])

# Squared error of the value function against the reference
def compute_error(qfunc):
    sum1 = 0.0
    for key in qfunc:
        error = qfunc[key] - best[key]
        sum1 += error * error
    return sum1

# Greedy policy
def greedy(qfunc, state):
    amax = 0
    key = "%d_%s" % (state, actions[0])
    qmax = qfunc[key]
    for i in range(len(actions)):  # scan the action space for the largest action value
        key = "%d_%s" % (state, actions[i])
        q = qfunc[key]
        if qmax < q:
            qmax = q
            amax = i
    return actions[amax]

# Epsilon-greedy policy
def epsilon_greedy(qfunc, state, epsilon):
    amax = 0
    key = "%d_%s" % (state, actions[0])
    qmax = qfunc[key]
    for i in range(len(actions)):  # scan the action space for the largest action value
        key = "%d_%s" % (state, actions[i])
        q = qfunc[key]
        if qmax < q:
            qmax = q
            amax = i
    # Action probabilities: extra mass 1 - epsilon on the greedy action,
    # plus epsilon spread uniformly over all actions
    pro = [0.0 for i in range(len(actions))]
    pro[amax] += 1 - epsilon
    for i in range(len(actions)):
        pro[i] += epsilon / len(actions)
    # Sample an action from that distribution
    r = random.random()  # uniform random float in [0, 1)
    s = 0.0
    for i in range(len(actions)):
        s += pro[i]
        if s >= r:
            return actions[i]
    return actions[len(actions) - 1]

def qlearning(num_iter1, alpha, epsilon):
    x = []
    y = []
    qfunc = dict()  # action-value function, stored as a dict
    # Initialize all action values to 0
    for s in states:
        for a in actions:
            key = "%d_%s" % (s, a)
            qfunc[key] = 0.0
    for iter1 in range(num_iter1):
        x.append(iter1)
        y.append(compute_error(qfunc))
        # Initialize the starting state and a random first action
        s = grid.reset()
        a = actions[int(random.random() * len(actions))]
        t = False
        count = 0
        while t is False and count < 100:
            key = "%d_%s" % (s, a)
            # Interact with the environment once; get the new state and the reward
            s1, r, t, _, _ = grid.step(a)
            # Greedy action at s1 (the bootstrap target)
            a1 = greedy(qfunc, s1)
            key1 = "%d_%s" % (s1, a1)
            # Q-learning update of the value function
            qfunc[key] = qfunc[key] + alpha * (r + gamma * qfunc[key1] - qfunc[key])
            # Move to the next state; behave epsilon-greedily
            s = s1
            a = epsilon_greedy(qfunc, s1, epsilon)
            count += 1
    plt.plot(x, y, "-.,", label="q alpha=%2.1f epsilon=%2.1f" % (alpha, epsilon))
    return qfunc
III. Training and testing code
All of the training and testing code lives in learning_and_test.py, listed below. Here num_iter1 is the number of training episodes; adjust it as the maze size changes so that the error curve converges:
from qlearning import *
import time

# main
if __name__ == "__main__":
    sleeptime = 0.5
    terminate_states = grid.env.getTerminate_states()
    # Read in the reference optimal value function
    read_best()
    # Train
    qfunc = qlearning(num_iter1=4000, alpha=0.2, epsilon=0.2)
    # Label and show the error curve
    plt.xlabel("number of iterations")
    plt.ylabel("square errors")
    plt.legend()
    plt.show()
    time.sleep(sleeptime)
    # Print the learned value function
    for s in states:
        for a in actions:
            key = "%d_%s" % (s, a)
            print("the qfunc of key (%s) is %f" % (key, qfunc[key]))
    # Print the learned policy
    print("the learned policy is:")
    for i in range(len(states)):
        if states[i] in terminate_states:
            print("the state %d is terminate_states" % (states[i]))
        else:
            print("the policy of state %d is (%s)" % (states[i], greedy(qfunc, states[i])))
    # Set the initial state of the system
    s0 = 1
    grid.env.setAction(s0)
    # Test the trained policy:
    # start from random states and follow the greedy path to the gold
    for i in range(20):
        # Random initial state
        s0 = grid.reset()
        grid.render()
        time.sleep(sleeptime)
        t = False
        count = 0
        # Check whether the random start is already a terminal state
        if s0 in terminate_states:
            print("reach the terminate state %d" % s0)
        else:
            while t is False and count < 100:
                a1 = greedy(qfunc, s0)
                print(s0, a1)
                grid.render()
                time.sleep(sleeptime)
                key = "%d_%s" % (s0, a1)
                # Interact with the environment once; get the new state and the reward
                s1, r, t, _, _ = grid.step(a1)
                if t is True:
                    # Print the terminal state
                    print(s1)
                    grid.render()
                    time.sleep(sleeptime)
                    print("reach the terminate state %d" % s1)
                # Move to the next state
                s0 = s1
                count += 1
    grid.close()
To create the maze you want, adjust the grid_size, trap_cells and end_cell settings at the top of GridEnv1.__init__ in grid_mdp1.py, as in the example below:
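For example, the 5x5 maze used in the Builder test block above would be configured like this (same values as that test):

# In GridEnv1.__init__ of grid_mdp1.py
self.grid_size = 5
self.trap_cells = {4, 9, 11, 12, 23, 24, 25}
self.end_cell = 15

Remember to retrain after changing the maze, since best_qfunc1 is regenerated from these settings every time the environment is created.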
Put learning_and_test.py and qlearning.py in the same folder and run learning_and_test.py to train and test. A series of logger.warn warnings will be printed while the code runs; they are harmless and can be ignored.
IV. Points that still need work
The way the convergence curve is computed is not great: the error should not be measured against the hand-built best_qfunc reference, so this part needs further revision.
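One possible direction (a sketch of mine, not part of the original code): track a convergence signal that needs no reference solution, such as the per-episode return. The function below assumes it lives in qlearning.py next to the original qlearning() and reuses grid, states, actions, gamma, greedy and epsilon_greedy from that file:

def qlearning_with_return(num_iter1, alpha, epsilon):
    # Same training loop as qlearning(), but the plotted curve is a smoothed
    # episode return instead of the squared error against best_qfunc1
    qfunc = {"%d_%s" % (s, a): 0.0 for s in states for a in actions}
    returns = []  # cumulative reward of each episode
    for _ in range(num_iter1):
        s = grid.reset()
        a = actions[int(random.random() * len(actions))]
        t = False
        count = 0
        ep_return = 0.0
        while t is False and count < 100:
            key = "%d_%s" % (s, a)
            s1, r, t, _, _ = grid.step(a)
            ep_return += r
            key1 = "%d_%s" % (s1, greedy(qfunc, s1))
            qfunc[key] += alpha * (r + gamma * qfunc[key1] - qfunc[key])
            s = s1
            a = epsilon_greedy(qfunc, s1, epsilon)
            count += 1
        returns.append(ep_return)
    # Moving average of the episode returns as the convergence curve
    window = 50
    smoothed = [sum(returns[max(0, k - window + 1):k + 1]) / (k + 1 - max(0, k - window + 1))
                for k in range(len(returns))]
    plt.plot(smoothed, label="mean episode return, window=%d" % window)
    return qfunc

A rising, flattening return curve indicates convergence without any hand-built reference table.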