【强化学习/tf/gym】(一)创建自定义gym环境

说在前面

  • 环境:Windows10
  • python版本:3.6
  • gym版本:0.18.3
  • 代码:github

目标

  • 本文将使用gym自定义一个简单的环境。如下所示:
    在这里插入图片描述
  • 其中蓝色小球为其它球,紫色小球agent/玩家控制的球。
  • 蓝色小球只会往一个方向移动,紫色小球可以往任意方向移动。
  • 所有球移动到边界外后都会从另一边进入。
  • 球的移动速度受球本身大小影响,球越大,移动越慢
  • 所有球之间都可以相互吞噬,但只能大球吃小球
  • 游戏目标是控制紫色小球吞噬其它球。

准备工作

开始

gym env

  • gym将环境抽象成一个类env,我们自定义的环境只需要继承该类并实现对应的接口即可
    # 源码地址 https://github.com/openai/gym/blob/master/gym/core.py
    class Env(object):
        """我们要关注的主要方法是:
            step
            reset
            render
            close
            seed
    
        还有以下的一些属性:
            action_space: 行为空间,在本例中,行为空间即[0,360),整型,因为我们支持360度的方向改变;
            			  行为空间可能不止一个维度,例如在某些支持多个行为的环境中,行为空间可以是[[0, 360], [0, 100], ...]
            observation_space: 观察空间,agent能看到的数据,可以是环境的部分或全部数据
            				   在本例中(为全量数据),我们将其设置为所有球的坐标、分数、以及球类型(自己还是其它球)
            reward_range: 奖励范围,agent执行一个动作后,环境给予的奖励值,在本例中为实数范围
        """
    
        def step(self, action):
            """环境的主要驱动函数,主逻辑将在该函数中实现。该函数可以按照时间轴,固定时间间隔调用
    
            参数:
                action (object): an action provided by the agent
    
            返回值:
                observation (object): agent对环境的观察,在本例中,直接返回环境的所有状态数据
                reward (float) : 奖励值,agent执行行为后环境反馈
                done (bool): 该局游戏时候结束,在本例中,只要自己被吃,该局结束
                info (dict): 函数返回的一些额外信息,可用于调试等
            """
            raise NotImplementedError
    
        def reset(self):
            """将环境重置为初始状态,并返回一个初始状态;在环境中有随机性的时候,需要注意每次重置后要保证与之前的环境相互独立
            """
            raise NotImplementedError
    
        def render(self, mode='human'):
            """环境渲染,用于将环境用图形、数字等表现出来
            """
            raise NotImplementedError
    
        def close(self):
            """一些环境数据的释放可以在该函数中实现
            """
            pass
    
        def seed(self, seed=None):
            """设置环境的随机生成器
            """
            return
    

action space

  • 如上文所述,本例使用的行为空间为0~359的整型数据,在gym中可以这样表达:
    self.action_space = spaces.Discrete(360)
    
  • gym.spaces中定义了一些常用的空间类,Discrete为其中之一,它表示的是 [ 0 , n − 1 ] [0,n-1] [0,n1]的离散空间,即 0 , 1 , 2 , 3 , . . . , n − 1 0,1,2,3,...,n-1 0,1,2,3,...,n1

observation space

  • 如上文所述,本例使用的观察空间为所有球的坐标、分数、以及球类型(自己还是其它球)。在gym中可以这样表达:
    self.observation_space = spaces.Box(low=0, high=VIEWPORT_H, shape=(MAX_BALL_NUM, 4), dtype=np.float32)
    
  • Box定义了一个 R n R^n Rn维度的空间。一般有两种表达方式:
    # 1
    spaces.Box(low=, high=, shape=(,...), dtype=)
    
    上述这种使用lowhigh来约束空间中每一个值得范围,当然也可以是无穷;shape定义了这个空间的维度,类似于numpy中的shapedtype定义了空间中每个值的类型。例如,RGB图像空间可以这样表示:
    s = spaces.Box(low=0, high=255, shape=(4, 5, 3), dtype=np.uint8)
    print(s.sample()) # 随机取样 5x4 rgb
    '''
    [[[ 81 163  74] [175 178 191] [199  89 210] [145 202  84] [116 252 115]]
     [[224 168  55] [236  98  21] [224 129 164] [150  64 204] [ 56  98 120]]
     [[ 78 179  96] [ 83 162 247] [159  48 184] [172 188 114] [ 61  68 147]]
     [[122 203 119] [ 80 237 171] [ 69 212 219] [ 65  62  62] [189 185 167]]]
    '''
    
    # 2
    Box(low=np.array([-1.0, -2.0]), high=np.array([2.0, 4.0]), dtype=np.float32)
    
    上述这种方式使用lowhigh来约束空间中不同位置的值,其shape取决于lowhighshape,例如:
    s = spaces.Box(low=np.array([[-1.0, -2.0], [-1.0, -2.0]]), high=np.array([[2.0, 4.0], [2.0, 4.0]]), dtype=np.float32)
    print(s.sample())
    ```
    [[0.7104309  2.0792136 ]
     [1.8466672  0.08392456]]
    ```
    

reset

  • 使用reset对环境进行重置,通常我们在初始化env的时候就会reset一次
  • 在本例中,我们的重置逻辑很简单,对每个球的坐标、分数进行随机,同时选择一个作为agent控制的球即可
    def reset(self):
    	# 管理所有球的列表, reset时先清空
        self.balls = []
    
        # 随机生成MAX_BALL_NUM - 1个其它球
        for i in range(MAX_BALL_NUM - 1):
            self.balls.append(self.randball(BALL_TYPE_OTHER))
    
        # 生成agent球
        self.selfball = self.randball(BALL_TYPE_SELF)
    
        # 把agent球加入管理列表
        self.balls.append(self.selfball)
    
        # 更新观察数据
        self.state = np.vstack([ball.state() for (_, ball) in enumerate(self.balls)])
    
    	# 返回
        return self.state
    
    @staticmethod
    def randball(_t: np.int):
    	# _t 为球类型(agent或其它)
    	# Ball class 参数为坐标x,y, 分数score, 类型_t
    	# VIEWPORT_W, VIEWPORT_H为地图宽高
        _b = Ball(np.random.rand(1)[0]*VIEWPORT_W, np.random.rand(1)[0]*VIEWPORT_H, np.random.rand(1)[0] * MAX_BALL_SCORE, np.int(np.random.rand(1)[0] * 360), _t)
        return _b
    
    

step

  • 在该函数中我们需要处理主逻辑,也就是:
    • 更新球的位置
    • 处理球之间的吞噬
    • 数据统计,例如reward
  • 让我们来看看代码,首先我们定义了一个Ball(代码比较简单,注释也有,就不展开了)
    class Ball():
        def __init__(self, x: np.float32, y: np.float32, score: np.float32, way: np.int, t: np.int):
            '''
                x   初始x坐标
                y   初始y坐标
                s   初始分
                w	移动方向,弧度值
                t   球类型
            '''
            self.x = x
            self.y = y
            self.s = score
            self.w = way * 2 * math.pi / 360.0  # 角度转弧度
            self.t = t
    
            self.id = GenerateBallID()      # 生成球唯一id
            self.lastupdate = time.time()   # 上一次的计算时间
            self.timescale = 100            # 时间缩放,或者速度的缩放
    
        def update(self, way):
            '''
                更新球的状态
            '''
    
            # 如果是agent球,那么就改变方向
            if self.t == BALL_TYPE_SELF:
                self.w = way * 2 * math.pi / 360.0  # 角度转弧度
    
            speed = 1.0 / self.s    # 分数转速度大小
            now = time.time()       # 当前时间值
    
            self.x += math.cos(self.w) * speed * (now - self.lastupdate) * self.timescale   # 距离=速度*时间
            self.y += math.sin(self.w) * speed * (now - self.lastupdate) * self.timescale   
    
            self.x = CheckBound(0, VIEWPORT_W, self.x)
            self.y = CheckBound(0, VIEWPORT_H, self.y)
    
            self.lastupdate = now   # 更新计算时间
    
        def addscore(self, score: np.float32):
            self.s += score
    
        def state(self):
            return [self.x, self.y, self.s, self.t]
    
  • 然后看看step函数
    def step(self, action):
        reward = 0.0	# 奖励初始值为0
        done = False	# 该局游戏是否结束
    
        # 首先调用ball.update方法更新球的状态
        for _, ball in enumerate(self.balls):
            ball.update(action)
    
        # 然后处理球之间的吞噬
        # 定一个要补充的球的类型列表,吃了多少球,就要补充多少球
        _new_ball_types = []
        # 遍历,这里就没有考虑性能问题了
        for _, A_ball in enumerate(self.balls):
            for _, B_ball in enumerate(self.balls):
    
    			# 自己,跳过
                if A_ball.id == B_ball.id:
                    continue
    
                # 先计算球A的半径
                # 我们使用球的分数作为球的面积
                A_radius = math.sqrt(A_ball.s / math.pi)
    
                # 计算球AB之间在x\y轴上的距离
                AB_x = math.fabs(A_ball.x - B_ball.x)
                AB_y = math.fabs(A_ball.y - B_ball.y)
    
                # 如果AB之间在x\y轴上的距离 大于 A的半径,那么B一定在A外
                if AB_x > A_radius or AB_y > A_radius:
                    continue
    
                # 计算距离
                if AB_x*AB_x + AB_y*AB_y > A_radius*A_radius:
                    continue
    
                # 如果agent球被吃掉,游戏结束
                if B_ball.t == BALL_TYPE_SELF:
                    done = True
    
                # A吃掉B A加上B的分数
                A_ball.addscore(B_ball.s)
    
                # 计算奖励
                if A_ball.t == BALL_TYPE_SELF:
                    reward += B_ball.s
    
                # 把B从列表中删除,并记录要增加一个B类型的球
                _new_ball_types.append(B_ball.t)
                self.balls.remove(B_ball)
    
        # 补充球
        for _, val in enumerate(_new_ball_types):
            self.balls.append(self.randball(np.int(val)))
    
    	# 填充观察数据
        self.state = np.vstack([ball.state() for (_, ball) in enumerate(self.balls)])
    
    	# 返回
        return self.state, reward, done, {}
    

render

  • render是一个比较重要的函数,如果你有可视化需求的话。
  • 用于渲染的方式并没有特别要求,你可以使用你习惯的方式来渲染数据,比如opencvopengl,甚至是matplotlib
  • 在本例中,我们使用的gym所使用的方式pyglet,并且gym对其进行了一定程度的封装,使用起来还是比较方便的
    def render(self, mode='human'):
        # 按照gym的方式创建一个viewer, 使用self.scale控制缩放大小
        from gym.envs.classic_control import rendering
        if self.viewer is None:
            self.viewer = rendering.Viewer(VIEWPORT_W * self.scale, VIEWPORT_H * self.scale)
    
        # 渲染所有的球
        for item in self.state:
        	# 从状态中获取坐标、分数、类型
            _x, _y, _s, _t = item[0] * self.scale, item[1] * self.scale, item[2], item[3]
    
    		# transform用于控制物体位置、缩放等
            transform = rendering.Transform()
            transform.set_translation(_x, _y)
    
            # 添加一个⚪,来表示球
            # 中心点: (x, y)
            # 半径: sqrt(score/pi)
            # 颜色: 其它球为蓝色、agent球为红/紫色
            self.viewer.draw_circle(math.sqrt(_s / math.pi) * self.scale, 30, color=(_t, 0, 1)).add_attr(transform)
    
    	# 然后直接渲染(没有考虑性能)
        return self.viewer.render(return_rgb_array = mode=='rgb_array')
    

使用

  • 一种方式是按照官方推荐那样,通过pip install来进行安装,然后用gym.make来创建
    • 下载github代码
    • 解压后将文件夹名改为gym-ball
    • 然后pip install -e gym-ball
    • 最后使用下述代码测试
      import gym
      
      env = gym.make('gym_ball:ball-v0')
      
      while True:
          env.step(150)
          env.render()
      
  • 另一种方法是直接创建BallEnv类实例即可
    class BallEnv(gym.Env):
        metadata = {'render.modes': ['human']}
    
        def __init__(self):
    
        def reset(self):
    
        def step(self, action):
    
        def render(self, mode='human'):
    
        def close(self):
    
    if __name__ == '__main__':
        env = BallEnv()
        
        while True:
            env.step(150)
            env.render()
    

在这里插入图片描述

  • 38
    点赞
  • 179
    收藏
    觉得还不错? 一键收藏
  • 18
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值