我们为什么要创建一个gym的环境呢?因为需要,哈哈哈,这是一句废话,但是也是一句真话。因为我不想自己写强化学习的算法了,我想用一些现成的框架,这些框架训练的都是gym的游戏,那我把我自己想要训练的东西改成一个gym的框架,不就可以直接用强化学习的框架来训练了嘛。就是这么一个简单的需求,我们开始吧。
顺便说一下,我这里的gym是比较老的版本,新版本的gym有一些不同,新版本的gym在step
函数中返回的是一个五元组,reset
返回的是一个二元组,这都与旧版本不同,但是这些强化学习的框架还没有改过来,所以我们也就使用了旧版本函数,不过这也不是什么大问题。
需要实现的函数
__init__()
构造函数中我们需要定义两个变量self.action_space
和self.observation_space
, 为了覆盖父类的变量,这两个变量的名字是固定的。它们定义了强化学习中的动作空间和状态空间的类型和大小,如果是离散的,则使用Discrete
创建,参数为离散量的个数,比如CartPole中,CartPole中的 self.action_space
实际就是用Discrete(2)
创建的,如果是连续的,则使用Box创建,比如CartPole这个例子中,状态有四维,而且状态空间的每个维度都有定义域,那么就可以如下创建:
self.action_space = Discrete(2)
high = np.array([
self.x_threshold * 2,
np.finfo(np.float32).max, # finfo可以显示响应类型的机器限制,这里为浮点数最大值
self.theta_threshold_radians * 2,
np.finfo(np.float32).max,
])
self.observation_space = spaces.Box(-high, high, dtype=np.float32)
reset()
调用这个方法可以重置模拟器环境,并返回重启后的模拟器中agent的初始state
step()
调用该方法以实现agent与simulator进行一次交互。我们的奖励机制也需要写在这个函数中,所以该函数非常重要。该函数的返回值必须是四元组,包含:
state 状态,也就是状态空间
reward 奖励,交互后agent得到的奖励
done 结束,true表示已经结束,False表示没有结束
info 信息,是一个字典,用来debug,一般用不到
render()
用来显示画面,一般pass,如果有能力写动画的话也可以
seed()
用来设置随机种子,一般pass,如果程序中有一些随机性的行为,可以在这里设置随机种子。
定义一个简单的环境
我们定义我们的动作空间为两个值,范围均为 [-1,1]
,状态空间或者说观测空间为离散的5个变量,奖励为两个动作的和,状态转移均为从0到1234,状态到4就结束。看一下代码
import gym
from gym import spaces
from stable_baselines3 import A2C
import numpy as np
class MySim(gym.Env):
def __init__(self):
low = np.array([-1,-1],dtype=np.float32)
high = np.array([1,1],dtype=np.float32)
self.action_space = spaces.Box(low,high,dtype=np.float32)
self.observation_space = spaces.Discrete(5)
self.state = 0
def step(self,action):
self.state += 1
reward = action[0] + action[1] # 这里的reward的类型是np,float32,不是python内建的float
done = False
if self.state == 4:
done = True
info = {}
return self.state, float(reward), done, info
def reset(self):
self.state = 0
return 0
def render(self,mode="human"):
pass
def seed():
pass
可以用一些工具来检测这个环境是否正确,比如 stable_baselines3
下的 check_env
,就可以帮助我们进行环境的检测
from stable_baselines3.common.env_checker import check_env
if __name__ == '__main__':
env = MySim()
check_env(env)
如果没有任何报错就说明环境正常,可以用这个环境来训练一些代码了,我们用了A2C算法
if __name__ == '__main__':
env = MySim()
check_env(env)
model = A2C(policy="MlpPolicy", env=env)
model.learn(total_timesteps=10000)
obs = env.reset()
# 验证一次
for _ in range(10):
action, state = model.predict(observation=obs)
print(action[0] + action[1])
obs, reward, done, info = env.step(action)
if done:
break
输出为
2.0
2.0
2.0
2.0
可以预见的输出正确。