强化学习Q-Learning训练机器人找金币(走迷宫)

本文利用Q-learning训练机器人找金币(走迷宫),所使用的各种包均为最新版本,经过一系列的扩展现在只需指定迷宫的大小,陷阱所在格子编号以及终点格子编号,同时调整训练的次数,即可实现任意大小迷宫的训练。

一、创建迷宫网格世界的gym环境

1.一个 gym 的环境文件,其主体是一个类,在这里我们定义类名为:GridEnv1,其初始化为环境的基本参数,网格世界的全部代码在文件 grid_mdp1.py 中。

import logging
import random

import gym
from gym.envs.classic_control import rendering
from gym.envs.classic_control import Builder

logger = logging.getLogger(__name__)


class GridEnv1(gym.Env):
    """Grid-world maze: a robot must reach the gold cell while avoiding traps.

    Cells are numbered 1..grid_size**2 in row-major order (top-left is 1).
    The transition, reward and reference-Q tables are produced by the
    Builder helper module, so only grid_size / trap_cells / end_cell need
    editing to define a new maze.
    """
    metadata = {
        "render.modes": ["human", "rgb_array"],
        "video.frames_per_second": 2
    }

    def __init__(self):

        # Grid size and the special cells (traps and the goal cell).
        self.grid_size = 10
        self.trap_cells = {1, 36, 39, 59, 10, 60, 47, 48, 82, 56, 89, 91, 28, 94}
        self.end_cell = 100

        self.states = []  # state space: cell ids 1..n
        self.n = self.grid_size * self.grid_size  # total number of states
        for i in range(self.n):
            self.states.append(i + 1)

        self.terminate_states = dict()  # terminal states stored as a dict (used as a set)
        for i in self.trap_cells:
            self.terminate_states[i] = 1
        self.terminate_states[self.end_cell] = 1

        self.actions = ['n', 'e', 's', 'w']  # action space: north/east/south/west

        self.action_space = gym.spaces.Discrete(4)  # gym action space
        self.observation_space = gym.spaces.Discrete(self.n)  # gym observation space

        # Builder instance that generates all lookup tables for this maze.
        self.builder = Builder.Builder(self.grid_size, self.trap_cells, self.end_cell)
        # Reward function
        self.rewards = dict()  # rewards keyed by "state_action" strings
        self.sequences = self.builder.rewards_builder()
        self.rewards = {key: value for key, value in self.sequences.items()}

        # State-transition function
        self.t = dict()  # transitions keyed by "state_action" strings
        self.sequences = self.builder.t_builder()
        self.t = {key: value for key, value in self.sequences.items()}

        # Write the reference Q-function to the file 'best_qfunc1'
        # (read back later by the training script to measure convergence).
        self.sequences = self.builder.best_qfunc_builder()
        self.sequences_outputs = [f"{key}:{value:.6f}" for key, value in self.sequences.items()]
        self.file_name = 'best_qfunc1'
        with open(self.file_name, 'w') as f:
            for item in self.sequences_outputs:
                f.write(f"{item}\n")

        self.gamma = 0.8  # discount factor
        self.viewer = None
        self.state = None

    def getTerminal(self):
        """Return the dict of terminal states (traps plus the goal)."""
        return self.terminate_states

    def getGamma(self):
        """Return the discount factor."""
        return self.gamma

    def getStates(self):
        """Return the list of all state ids."""
        return self.states

    def getAction(self):
        """Return the list of action labels."""
        return self.actions

    def getTerminate_states(self):
        """Return the dict of terminal states (same object as getTerminal)."""
        return self.terminate_states

    def setAction(self, s):
        """Set the current state (despite the name, this sets the state, not an action)."""
        self.state = s

    def step(self, action):
        """Take one step; returns (next_state, reward, terminated, truncated, info)."""
        # current state of the system
        state = self.state
        # If already in a terminal state, stay put with zero reward.
        if state in self.terminate_states:
            return state, 0, True, False, {}
        key = "%d_%s" % (state, action)  # "state_action" key into the lookup tables

        # State transition: a move with no table entry (off-grid) keeps the state.
        if key in self.t:
            next_state = self.t[key]
        else:
            next_state = state
        self.state = next_state

        is_terminal = False

        if next_state in self.terminate_states:
            is_terminal = True

        if key not in self.rewards:
            r = 0.0
        else:
            r = self.rewards[key]

        return next_state, r, is_terminal, False, {}

    def reset(self):
        """Reset to a uniformly random state and return its id.

        NOTE(review): returns only the observation (old gym API) while
        step() returns a 5-tuple (new API); confirm against the installed
        gym version.
        """
        self.state = self.states[int(random.random() * len(self.states))]
        return self.state

    def render(self, mode='human', close=False):
        """Draw the maze with pyglet: grid lines, traps (black), gold, robot."""
        # Each cell is 100x100 px with a 100 px margin on every side.
        screen_width = 100 * (self.grid_size + 2)
        screen_height = 100 * (self.grid_size + 2)

        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
            return

        if self.viewer is None:
            self.viewer = rendering.Viewer(screen_width, screen_height)
            # Draw the grid lines (vertical, then horizontal).
            for i in range(100, 100 * (self.grid_size + 1) + 1, 100):
                line = rendering.Line((i, 100 * (self.grid_size + 1)), (i, 100))
                line.set_color(0, 0, 0)
                self.viewer.add_geom(line)
            for j in range(100, 100 * (self.grid_size + 1) + 1, 100):
                line = rendering.Line((100, j), (100 * (self.grid_size + 1), j))
                line.set_color(0, 0, 0)
                self.viewer.add_geom(line)

            # Shapes for the traps, the gold and the robot.
            # colors
            obstacle_color = (0, 0, 0)
            gold_color = (1, 0.9, 0)
            robot_color = (0.8, 0.6, 0.4)  # NOTE: unused; the robot color is hard-coded below
            # Pixel-center coordinates of the robot for each state id.
            self.x = []
            self.y = []
            # Pixel centers of the trap cells and of the gold (goal) cell.
            self.obstacles = []
            self.gold = []
            # `flag` makes x/y (and gold) be filled only on the first pass;
            # the outer loop re-scans the whole grid once per trap cell to
            # locate its coordinates.
            flag = 1
            for k in self.trap_cells:
                count = 1
                for i in range(self.grid_size):
                    for j in range(self.grid_size):
                        if flag:
                            self.x.append(150 + 100 * j)
                            self.y.append(self.grid_size * 100 + 50 - 100 * i)
                        if k == count:
                            self.obstacles.append((150 + 100 * j, self.grid_size * 100 + 50 - 100 * i))
                        if self.end_cell == count and flag:
                            self.gold.append((150 + 100 * j, self.grid_size * 100 + 50 - 100 * i))
                        count += 1
                flag = 0
            # draw the traps ("skulls")
            for (x, y) in self.obstacles:
                self._add_circle(x, y, 50, obstacle_color)
            # draw the gold
            for (x, y) in self.gold:
                self._add_circle(x, y, 50, gold_color)
            # draw the robot
            self.robot = rendering.make_circle(30)
            self.robotrans = rendering.Transform()
            self.robot.add_attr(self.robotrans)
            self.robot.set_color(0.8, 0.6, 0.4)
            self.viewer.add_geom(self.robot)

        if self.state is None:
            return None
        # Move the robot to the center of the current state's cell.
        self.robotrans.set_translation(self.x[self.state - 1], self.y[self.state - 1])

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    def _add_circle(self, x, y, radius, color):
        """Add a filled circle geom at (x, y) with the given color; return it."""
        circle = rendering.make_circle(radius)
        circletrans = rendering.Transform(translation=(x, y))
        circle.add_attr(circletrans)
        circle.set_color(*color)
        self.viewer.add_geom(circle)
        return circle

    def close(self):
        """Close the rendering window, if open."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None
        return

2.下面重点讲一讲如何将建好的环境进行注册, 以便通过 gym 的标准形式进行调用:

第一步:将我们自己的环境文件(我创建的文件名为 grid_mdp1.py)拷贝到你的 gym
安装目录/gym/gym/envs/classic_control 文件夹中。(拷贝在这个文件夹中因为要使用
rendering 模块。当然,也有其他办法。该方法不唯一)
使用rendering的方法:拷贝下面的代码到上面的文件夹中
 
"""
2D rendering framework
"""
from __future__ import division
import os
import six
import sys

if "Apple" in sys.version:
    if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ:
        os.environ['DYLD_FALLBACK_LIBRARY_PATH'] += ':/usr/lib'
        # (JDS 2016/04/15): avoid bug on Anaconda 2.3.0 / Yosemite

from gym import error

try:
    import pyglet
except ImportError as e:
    raise ImportError('''
    Cannot import pyglet.
    HINT: you can install pyglet directly via 'pip install pyglet'.
    But if you really just want to install all Gym dependencies and not have to think about it,
    'pip install -e .[all]' or 'pip install gym[all]' will do it.
    ''')

try:
    from pyglet.gl import *
except ImportError as e:
    raise ImportError('''
    Error occurred while running `from pyglet.gl import *`
    HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get install python-opengl'.
    If you're running on a server, you may need a virtual frame buffer; something like this should work:
    'xvfb-run -s \"-screen 0 1400x900x24\" python <your_script.py>'
    ''')

import math
import numpy as np

RAD2DEG = 57.29577951308232

def get_display(spec):
    """Turn a display specification (e.g. ':0' or None) into a pyglet Display.

    Pyglet only supports multiple Displays on Linux.
    """
    if spec is None:
        return None
    if isinstance(spec, six.string_types):
        return pyglet.canvas.Display(spec)
    raise error.Error('Invalid display specification: {}. (Must be a string like :0 or None.)'.format(spec))

class Viewer(object):
    """Pyglet/OpenGL window that renders a list of geometries each frame."""

    def __init__(self, width, height, display=None):
        display = get_display(display)

        self.width = width
        self.height = height
        self.window = pyglet.window.Window(width=width, height=height, display=display)
        self.window.on_close = self.window_closed_by_user
        self.isopen = True
        self.geoms = []          # persistent geometries
        self.onetime_geoms = []  # geometries drawn for a single frame
        self.transform = Transform()

        # Enable alpha blending for translucent colors.
        glEnable(GL_BLEND)
        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)

    def close(self):
        self.window.close()

    def window_closed_by_user(self):
        self.isopen = False

    def set_bounds(self, left, right, bottom, top):
        """Map world coordinates [left,right]x[bottom,top] onto the window."""
        assert right > left and top > bottom
        scalex = self.width/(right-left)
        scaley = self.height/(top-bottom)
        self.transform = Transform(
            translation=(-left*scalex, -bottom*scaley),
            scale=(scalex, scaley))

    def add_geom(self, geom):
        """Register a geometry drawn on every frame."""
        self.geoms.append(geom)

    def add_onetime(self, geom):
        """Register a geometry drawn only on the next frame."""
        self.onetime_geoms.append(geom)

    def render(self, return_rgb_array=False):
        """Draw all geometries; optionally return the frame as an RGB array."""
        glClearColor(1, 1, 1, 1)
        self.window.clear()
        self.window.switch_to()
        self.window.dispatch_events()
        self.transform.enable()
        for geom in self.geoms:
            geom.render()
        for geom in self.onetime_geoms:
            geom.render()
        self.transform.disable()
        arr = None
        if return_rgb_array:
            buffer = pyglet.image.get_buffer_manager().get_color_buffer()
            image_data = buffer.get_image_data()
            arr = np.frombuffer(image_data.data, dtype=np.uint8)
            # In https://github.com/openai/gym-http-api/issues/2, we
            # discovered that someone using Xmonad on Arch was having
            # a window of size 598 x 398, though a 600 x 400 window
            # was requested. (Guess Xmonad was preserving a pixel for
            # the boundary.) So we use the buffer height/width rather
            # than the requested one.
            arr = arr.reshape(buffer.height, buffer.width, 4)
            arr = arr[::-1, :, 0:3]
        self.window.flip()
        self.onetime_geoms = []
        return arr if return_rgb_array else self.isopen

    # Convenience
    def draw_circle(self, radius=10, res=30, filled=True, **attrs):
        geom = make_circle(radius=radius, res=res, filled=filled)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom

    def draw_polygon(self, v, filled=True, **attrs):
        geom = make_polygon(v=v, filled=filled)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom

    def draw_polyline(self, v, **attrs):
        geom = make_polyline(v=v)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom

    def draw_line(self, start, end, **attrs):
        geom = Line(start, end)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom

    def get_array(self):
        """Return the current frame as an (height, width, 3) uint8 RGB array."""
        self.window.flip()
        image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
        self.window.flip()
        # Bug fix: np.fromstring is deprecated (and removed for binary data
        # in recent NumPy); np.frombuffer is the supported equivalent and is
        # what render() above already uses.
        arr = np.frombuffer(image_data.data, dtype=np.uint8)
        arr = arr.reshape(self.height, self.width, 4)
        return arr[::-1, :, 0:3]

    def __del__(self):
        self.close()

def _add_attrs(geom, attrs):
    if "color" in attrs:
        geom.set_color(*attrs["color"])
    if "linewidth" in attrs:
        geom.set_linewidth(attrs["linewidth"])

class Geom(object):
    """Base class for drawable geometry.

    Attributes are enabled in reverse insertion order around the concrete
    drawing call (render1), then disabled in insertion order.
    """
    def __init__(self):
        self._color = Color((0, 0, 0, 1.0))
        self.attrs = [self._color]

    def render(self):
        """Enable attributes (last added first), draw, then disable them."""
        for attr in reversed(self.attrs):
            attr.enable()
        self.render1()
        for attr in self.attrs:
            attr.disable()

    def render1(self):
        """Subclasses implement the actual drawing here."""
        raise NotImplementedError

    def add_attr(self, attr):
        self.attrs.append(attr)

    def set_color(self, r, g, b):
        self._color.vec4 = (r, g, b, 1)

class Attr(object):
    """Abstract rendering attribute toggled around a geometry's draw call."""
    def enable(self):
        raise NotImplementedError

    def disable(self):
        # Most attributes need no teardown.
        pass

class Transform(Attr):
    """Translate / rotate / scale attribute applied via the GL matrix stack."""
    def __init__(self, translation=(0.0, 0.0), rotation=0.0, scale=(1, 1)):
        self.set_translation(*translation)
        self.set_rotation(rotation)
        self.set_scale(*scale)

    def enable(self):
        glPushMatrix()
        # Order matters: translate, then rotate (radians -> degrees), then scale.
        glTranslatef(self.translation[0], self.translation[1], 0)
        glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0)
        glScalef(self.scale[0], self.scale[1], 1)

    def disable(self):
        glPopMatrix()

    def set_translation(self, newx, newy):
        self.translation = (float(newx), float(newy))

    def set_rotation(self, new):
        self.rotation = float(new)

    def set_scale(self, newx, newy):
        self.scale = (float(newx), float(newy))

class Color(Attr):
    """RGBA color attribute (vec4 is an (r, g, b, a) tuple)."""
    def __init__(self, vec4):
        self.vec4 = vec4

    def enable(self):
        glColor4f(*self.vec4)

class LineStyle(Attr):
    """Stippled (dashed) line pattern attribute."""
    def __init__(self, style):
        self.style = style

    def enable(self):
        glEnable(GL_LINE_STIPPLE)
        glLineStipple(1, self.style)

    def disable(self):
        glDisable(GL_LINE_STIPPLE)

class LineWidth(Attr):
    """Line stroke-width attribute."""
    def __init__(self, stroke):
        self.stroke = stroke

    def enable(self):
        glLineWidth(self.stroke)

class Point(Geom):
    """A single point drawn at the local origin."""
    def __init__(self):
        Geom.__init__(self)

    def render1(self):
        glBegin(GL_POINTS)
        glVertex3f(0.0, 0.0, 0.0)
        glEnd()

class FilledPolygon(Geom):
    """Filled polygon; the GL primitive is chosen from the vertex count."""
    def __init__(self, v):
        Geom.__init__(self)
        self.v = v

    def render1(self):
        n = len(self.v)
        if n == 4:
            glBegin(GL_QUADS)
        elif n > 4:
            glBegin(GL_POLYGON)
        else:
            glBegin(GL_TRIANGLES)
        for p in self.v:
            glVertex3f(p[0], p[1], 0)  # draw each vertex
        glEnd()

def make_circle(radius=10, res=30, filled=True):
    """Circle of the given radius approximated by *res* vertices."""
    points = [(math.cos(2*math.pi*i / res) * radius,
               math.sin(2*math.pi*i / res) * radius)
              for i in range(res)]
    return FilledPolygon(points) if filled else PolyLine(points, True)

def make_polygon(v, filled=True):
    """Polygon from vertex list *v*: filled, or a closed outline."""
    if filled:
        return FilledPolygon(v)
    return PolyLine(v, True)

def make_polyline(v):
    """Open polyline through the vertices *v*."""
    return PolyLine(v, False)

def make_capsule(length, width):
    """Capsule: a length x width box capped with a circle at each end."""
    l, r, t, b = 0, length, width/2, -width/2
    box = make_polygon([(l, b), (l, t), (r, t), (r, b)])
    cap_left = make_circle(width/2)
    cap_right = make_circle(width/2)
    cap_right.add_attr(Transform(translation=(length, 0)))
    return Compound([box, cap_left, cap_right])

class Compound(Geom):
    """Several geoms rendered as one.

    Child Color attributes are stripped so the compound's own color applies.
    """
    def __init__(self, gs):
        Geom.__init__(self)
        self.gs = gs
        for g in self.gs:
            g.attrs = [a for a in g.attrs if not isinstance(a, Color)]

    def render1(self):
        for g in self.gs:
            g.render()

class PolyLine(Geom):
    """Connected line through a vertex list, optionally closed into a loop."""
    def __init__(self, v, close):
        Geom.__init__(self)
        self.v = v
        self.close = close
        self.linewidth = LineWidth(1)
        self.add_attr(self.linewidth)

    def render1(self):
        mode = GL_LINE_LOOP if self.close else GL_LINE_STRIP
        glBegin(mode)
        for p in self.v:
            glVertex3f(p[0], p[1], 0)  # draw each vertex
        glEnd()

    def set_linewidth(self, x):
        self.linewidth.stroke = x

class Line(Geom):
    """Straight segment from *start* to *end*."""
    def __init__(self, start=(0.0, 0.0), end=(0.0, 0.0)):
        Geom.__init__(self)
        self.start = start
        self.end = end
        self.linewidth = LineWidth(1)
        self.add_attr(self.linewidth)

    def render1(self):
        glBegin(GL_LINES)
        glVertex2f(*self.start)
        glVertex2f(*self.end)
        glEnd()

class Image(Geom):
    """Bitmap loaded from *fname*, drawn centered on the local origin."""
    def __init__(self, fname, width, height):
        Geom.__init__(self)
        self.width = width
        self.height = height
        self.img = pyglet.image.load(fname)
        self.flip = False

    def render1(self):
        self.img.blit(-self.width/2, -self.height/2,
                      width=self.width, height=self.height)

# ================================================================

class SimpleImageViewer(object):
    """Minimal window that blits an RGB numpy array on each imshow() call."""

    def __init__(self, display=None, maxwidth=500):
        self.window = None
        self.isopen = False
        self.display = display
        self.maxwidth = maxwidth

    def imshow(self, arr):
        """Display the (H, W, 3) uint8 array *arr*, creating the window lazily."""
        if self.window is None:
            height, width, _channels = arr.shape
            if width > self.maxwidth:
                # Scale the window down to maxwidth, preserving aspect ratio.
                scale = self.maxwidth / width
                width = int(scale * width)
                height = int(scale * height)
            self.window = pyglet.window.Window(width=width, height=height,
                display=self.display, vsync=False, resizable=True)
            self.width = width
            self.height = height
            self.isopen = True

            @self.window.event
            def on_resize(width, height):
                self.width = width
                self.height = height

            @self.window.event
            def on_close():
                self.isopen = False

        assert len(arr.shape) == 3, "You passed in an image with the wrong number shape"
        image = pyglet.image.ImageData(arr.shape[1], arr.shape[0],
            'RGB', arr.tobytes(), pitch=arr.shape[1]*-3)
        # Bug fix: this module does `from pyglet.gl import *`, which binds the
        # GL names unprefixed but never binds a bare `gl` module name, so the
        # original `gl.glTexParameteri(...)` raised NameError at runtime.
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST)
        texture = image.get_texture()
        # Stretch the texture to the (possibly resized) window.
        texture.width = self.width
        texture.height = self.height
        self.window.clear()
        self.window.switch_to()
        self.window.dispatch_events()
        texture.blit(0, 0)  # draw
        self.window.flip()

    def close(self):
        if self.isopen and sys.meta_path:
            # check sys.meta_path to avoid 'ImportError: sys.meta_path is None, Python is likely shutting down'
            self.window.close()
            self.isopen = False

    def __del__(self):
        self.close()

需要注意的是,要成功调用rendering,你还需要安装下面两个依赖库:

pip install six
pip install pyglet==1.5.27

然后,再将我编写的Builder.py 拷贝到上面的文件夹中:

class Builder:
    """Generates transition, reward and reference-Q tables for a square
    grid world with trap cells and a single goal (end) cell.

    Cells are numbered 1..grid_size**2 in row-major order (top-left is 1).
    All table keys have the form "<cell>_<direction>".
    """

    # Action labels, in the fixed order used by every table.
    _DIRECTIONS = ('n', 'e', 's', 'w')

    def __init__(self, grid_size, trap_cells, end_cell):
        self.grid_size = grid_size
        self.trap_cells = trap_cells
        self.end_cell = end_cell

    def _neighbor(self, cell, direction):
        """Return the cell reached from `cell` by moving `direction`,
        or None when the move would leave the grid.

        Centralizes the boundary logic that the original three builder
        methods each duplicated with slightly different (but equivalent)
        conditions.
        """
        size = self.grid_size
        if direction == 'n' and cell > size:
            return cell - size
        if direction == 'e' and cell % size != 0:
            return cell + 1
        if direction == 's' and cell <= size * (size - 1):
            return cell + size
        if direction == 'w' and cell % size != 1:
            return cell - 1
        return None

    def t_builder(self):
        """Deterministic transition table: "{cell}_{dir}" -> next cell.

        Only non-terminal cells get entries; off-grid moves are omitted
        (the environment then keeps the state unchanged).
        """
        sequences = {}
        for cell in range(1, self.grid_size ** 2 + 1):
            if cell in self.trap_cells or cell == self.end_cell:
                continue
            for direction in self._DIRECTIONS:
                target = self._neighbor(cell, direction)
                if target is not None:
                    sequences[f"{cell}_{direction}"] = target
        return sequences

    def rewards_builder(self):
        """Reward table: -1.0 for stepping into a trap, +1.0 for reaching
        the goal; all other moves have no entry (implicit 0 reward)."""
        sequences = {}
        for cell in range(1, self.grid_size ** 2 + 1):
            if cell in self.trap_cells or cell == self.end_cell:
                continue
            for direction in self._DIRECTIONS:
                target = self._neighbor(cell, direction)
                if target is None:
                    continue
                if target in self.trap_cells:
                    sequences[f"{cell}_{direction}"] = -1.0
                elif target == self.end_cell:
                    sequences[f"{cell}_{direction}"] = 1.0
        return sequences

    def best_qfunc_builder(self):
        """Reference Q-values for every (cell, direction) pair.

        Terminal cells get 0; otherwise -1 toward a trap, +1 toward the
        goal, 0.512 for a blocked (off-grid) move and 0.64 for any other
        legal move.
        """
        sequences = {}
        for cell in range(1, self.grid_size ** 2 + 1):
            terminal = cell in self.trap_cells or cell == self.end_cell
            for direction in self._DIRECTIONS:
                if terminal:
                    value = 0.0
                else:
                    target = self._neighbor(cell, direction)
                    if target is None:
                        value = 0.512
                    elif target in self.trap_cells:
                        value = -1.0
                    elif target == self.end_cell:
                        value = 1.0
                    else:
                        value = 0.64
                sequences[f"{cell}_{direction}"] = value
        return sequences


""" 测试
# 定义网格大小和特殊格子
grid_size = 5
trap_cells = {4, 9, 11, 12, 23, 24, 25}
end_cell = 15

# 创建Builder实例
builder = Builder(grid_size, trap_cells, end_cell)

# 调用不同方法生成序列并打印结果
print("状态转移t:")
sequences = builder.t_builder()
sequences_outputs = [f"self.t['{key}'] = {value}" for key, value in sequences.items()]
for i in range(len(sequences_outputs)):
    print(sequences_outputs[i])

print("奖励rewards:")
sequences = builder.rewards_builder()
sequences_outputs = [f"self.rewards['{key}'] = {value}" for key, value in sequences.items()]
for i in range(len(sequences_outputs)):
    print(sequences_outputs[i])

print("\nbest_qfunc:")
sequences = builder.best_qfunc_builder()
sequences_outputs = [f"{key}:{value:.6f}" for key, value in sequences.items()]
file_name = 'best_qfunc2'
with open(file_name, 'w') as f:
    for item in sequences_outputs:
        f.write(f"{item}\n")
"""
第二步:打开该文件夹(第一步中的文件夹)下的__init__.py 文件,在文件末尾加入语句:
from gym.envs.classic_control.grid_mdp1 import GridEnv1

第三步:进入文件夹你的 gym 安装目录/gym/gym/envs(即第一步中的文件夹的上一级文件夹),打开该文件夹下的 __init__.py 文件,添加代码:

register(
    id='GridWorld-v1',
    entry_point='gym.envs.classic_control.grid_mdp1:GridEnv1',
    max_episode_steps=200,
    reward_threshold=100.0,
)

第一个参数 id 就是你调用 gym.make('id') 时的 id,这个 id 你可以随便选取,我取的名字是 GridWorld-v1。
第二个参数就是函数路口了。 后面的参数原则上来说可以不必要写。
经过以上三步,就完成了注册。

3.测试

打开Anaconda Prompt,激活你的虚拟环境,依次输入下面每行代码:

python
import gym
env = gym.make('GridWorld-v1')
env.reset()
env.render()

如果没有问题,就会显示一张环境图片(emmm,它有点大,失策了,得调整一下尺寸) 。

最后输入env.close(),关闭环境。 

二、Q-learning算法

Qlearning的所有代码放在如下qlearning.py中,如下:

import gym
import random

random.seed(0)  # fix the RNG seed so every run produces the same random sequence
import matplotlib.pyplot as plt

grid = gym.make('GridWorld-v1')
# create the grid-world environment
states = grid.env.getStates()  # state space of the grid world
actions = grid.env.getAction()  # action space of the grid world
gamma = grid.env.getGamma()  # discount factor
# used to measure the gap between the current policy and the optimal one
best = dict()  # holds the reference (optimal) action-value function, filled by read_best()


def read_best():
    """Populate the global `best` dict from the 'best_qfunc1' file.

    Each non-empty line has the form "<state>_<action>:<value>" (as written
    by the environment's __init__).
    """
    # `with` guarantees the file handle is closed; the original opened the
    # file and never closed it.
    with open("best_qfunc1") as f:
        for line in f:
            line = line.strip()
            if len(line) == 0:
                continue
            eles = line.split(":")
            best[eles[0]] = float(eles[1])


# 计算值函数的误差
def compute_error(qfunc):
    """Sum of squared differences between `qfunc` and the reference `best`."""
    total = 0.0
    for key in qfunc:
        diff = qfunc[key] - best[key]
        total += diff * diff
    return total


#  贪婪策略
def greedy(qfunc, state):
    """Return the action with the largest Q-value in `state`.

    Ties are broken in favor of the earliest action in `actions`.
    """
    best_idx = 0
    best_q = qfunc["%d_%s" % (state, actions[0])]
    for idx, act in enumerate(actions):
        q = qfunc["%d_%s" % (state, act)]
        if q > best_q:
            best_q = q
            best_idx = idx
    return actions[best_idx]


# ######epsilon贪婪策略
def epsilon_greedy(qfunc, state, epsilon):
    """Epsilon-greedy selection: probability 1-epsilon on the greedy action,
    with epsilon spread uniformly over all actions."""
    # Index of the greedy (highest-Q) action; first maximum wins.
    best_idx = 0
    best_q = qfunc["%d_%s" % (state, actions[0])]
    for idx in range(len(actions)):
        q = qfunc["%d_%s" % (state, actions[idx])]
        if q > best_q:
            best_q = q
            best_idx = idx

    # Per-action probabilities.
    n_actions = len(actions)
    probs = [epsilon / n_actions for _ in range(n_actions)]
    probs[best_idx] += 1 - epsilon

    # Sample one action from the distribution (single random draw,
    # same RNG consumption as the original).
    threshold = random.random()
    cumulative = 0.0
    for idx in range(n_actions):
        cumulative += probs[idx]
        if cumulative >= threshold:
            return actions[idx]
    return actions[n_actions - 1]


def qlearning(num_iter1, alpha, epsilon):
    """Tabular Q-learning on the grid-world environment.

    num_iter1: number of training episodes
    alpha: learning rate
    epsilon: exploration rate of the behavior policy
    Returns the learned action-value dict keyed by "state_action"; also
    plots the per-episode squared error against the reference `best`.
    """
    x = []
    y = []
    qfunc = dict()  # action-value function stored as a dict
    # initialize all action values to 0
    for s in states:
        for a in actions:
            key = "%d_%s" % (s, a)
            qfunc[key] = 0.0
    for iter1 in range(num_iter1):
        x.append(iter1)
        y.append(compute_error(qfunc))

        # start an episode from a random state with a random first action
        s = grid.reset()
        a = actions[int(random.random() * len(actions))]
        t = False
        count = 0
        while t is False and count < 100:
            key = "%d_%s" % (s, a)
            # one interaction with the environment: next state and reward
            s1, r, t, _, _ = grid.step(a)
            # greedy (max-Q) action at s1 — the Q-learning bootstrap target
            a1 = greedy(qfunc, s1)
            key1 = "%d_%s" % (s1, a1)
            # Q-learning update rule
            qfunc[key] = qfunc[key] + alpha * (r + gamma * qfunc[key1] - qfunc[key])
            # advance; behave epsilon-greedily for the next step
            s = s1
            a = epsilon_greedy(qfunc, s1, epsilon)
            count += 1
    plt.plot(x, y, "-.,", label="q alpha=%2.1f epsilon=%2.1f" % (alpha, epsilon))
    return qfunc

三、训练及测试代码 

训练及测试的所有代码放在learning_and_test.py中,代码如下,其中num_iter1是训练次数,请随迷宫的尺寸变化自行调整其大小,以使误差函数收敛:

from qlearning import *
import time

# main函数
# main entry point
if __name__ == "__main__":

    sleeptime = 0.5  # pause between rendered frames, in seconds
    terminate_states = grid.env.getTerminate_states()
    # load the reference (optimal) value function from file
    read_best()
    # train
    qfunc = qlearning(num_iter1=4000, alpha=0.2, epsilon=0.2)
    # axis labels for the error curve
    plt.xlabel("number of iterations")
    plt.ylabel("square errors")
    plt.legend()
    # show the error curve
    plt.show()
    time.sleep(sleeptime)
    # print the learned value function
    for s in states:
        for a in actions:
            key = "%d_%s" % (s, a)
            print("the qfunc of key (%s) is %f" % (key, qfunc[key]))
            # qfunc[key]
    # print the learned (greedy) policy
    print("the learned policy is:")
    for i in range(len(states)):
        if states[i] in terminate_states:
            print("the state %d is terminate_states" % (states[i]))
        else:
            print("the policy of state %d is (%s)" % (states[i], greedy(qfunc, states[i])))
    # set the initial state of the system
    s0 = 1
    grid.env.setAction(s0)
    # test the trained policy:
    # run 20 episodes from random starts, following the greedy policy
    for i in range(20):
        # random initial state
        s0 = grid.reset()
        grid.render()
        time.sleep(sleeptime)
        t = False
        count = 0
        # check whether the random start is already a terminal state
        if s0 in terminate_states:
            print("reach the terminate state %d" % s0)
        else:
            while t is False and count < 100:
                a1 = greedy(qfunc, s0)
                print(s0, a1)
                grid.render()
                time.sleep(sleeptime)
                key = "%d_%s" % (s0, a1)
                # one interaction with the environment: next state and reward.
                # NOTE(review): `i` here clobbers the episode loop variable
                # with step()'s info dict; harmless (the for statement
                # rebinds it each iteration) but better named `info`.
                s1, r, t, _, i = grid.step(a1)
                if t is True:
                    # print the terminal state
                    print(s1)
                    grid.render()
                    time.sleep(sleeptime)
                    print("reach the terminate state %d" % s1)
                # advance to the next state
                s0 = s1
                count += 1
    grid.close()

请调整 grid_mdp1.py 中 __init__ 开头定义的 grid_size、trap_cells、end_cell 三个参数以创建你想要的迷宫:

将learning_and_test.py和qlearning.py放在同一个文件夹下,即可运行 learning_and_test.py进行训练及测试。代码运行过程中会出现一系列 logger.warn 警告(并非报错),不用管。

四、需要修改的点

该代码的损失收敛曲线计算方式并不是很棒,不应该使用best_qfunc来计算,需要再做修改

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值