简介
gym的核心接口是Env,作为统一的环境接口。Env包含下面几个核心方法:
- reset(self):重置环境的状态,返回观察。
- step(self, action):推进一个时间步长,返回observation,reward,done,info
- render(self, mode='human', close=False):重绘环境的一帧。默认模式一般比较友好,如弹出一个窗口。
render
用 gym 搭建这个简单的环境
绘制可视化环境
import gym
from gym.envs.classic_control import rendering # render 函数里要有这个包,否则报错
class GridEnv(gym.Env):
    """Render-only grid world: draws the grid, one coin, two traps and a robot.

    This version has no state, ``step`` or ``reset``; it only demonstrates
    how a frame is drawn with the classic-control ``rendering`` helpers.
    """

    # (start, end) pixel coordinates of every grid line segment.
    _GRID_LINES = (
        ((100, 300), (500, 300)),
        ((100, 200), (500, 200)),
        ((100, 300), (100, 100)),
        ((180, 300), (180, 100)),
        ((260, 300), (260, 100)),
        ((340, 300), (340, 100)),
        ((420, 300), (420, 100)),
        ((500, 300), (500, 100)),
        ((100, 100), (180, 100)),
        ((260, 100), (340, 100)),
        ((420, 100), (500, 100)),
    )

    def __init__(self):
        # 600x400 pixel window.
        self.viewer = rendering.Viewer(600, 400)

    def render(self, mode='human'):
        """Draw one frame of the static scene.

        Args:
            mode: ``'human'`` shows the window; ``'rgb_array'`` additionally
                asks the viewer to return the frame as an RGB array.

        Returns:
            Whatever ``self.viewer.render`` returns for the requested mode.
        """
        # Grid
        for start, end in self._GRID_LINES:
            self.viewer.draw_line(start, end)

        # Coin: a circle moved into place by a translation attribute.
        coin = self.viewer.draw_circle(40, color=(1, 0.9, 0))
        coin.add_attr(rendering.Transform(translation=(300, 150)))

        # Traps
        for centre in ((140, 150), (460, 150)):
            trap = self.viewer.draw_circle(40, color=(0, 0, 0))
            trap.add_attr(rendering.Transform(translation=centre))

        # Robot
        robot = self.viewer.draw_circle(30, color=(0.8, 0.6, 0.4))
        robot.add_attr(rendering.Transform(translation=(140, 250)))

        # Only request pixel data when the caller asked for 'rgb_array';
        # the previous comparison against 'human' had the two modes swapped.
        return self.viewer.render(return_rgb_array=(mode == 'rgb_array'))
# Smoke test: create the environment and redraw the same frame indefinitely.
env = GridEnv()
while True:
    env.render()
step
step()函数的输入是动作,输出是下一个时刻的状态(观测)、回报、是否终止和调试信息。对于调试信息,可以为空,但不能缺少,否则会报错,常用{}来代替。
有了状态空间、动作空间和状态转移概率,我们便可以写 step(a) 函数了。
状态空间:
# State space: eight states labelled 1 through 8.
self.states = [1,2,3,4,5,6,7,8]
动作空间:
# Action space: four action labels (presumably north/east/south/west -- confirm against the transition table).
self.actions = ['n','e','s','w']
回报函数:
# Reward table stored as a dict; each key is a "<state>_<action>" string.
self.rewards = {
    '1_s': -1.0,
    '3_s': 1.0,
    '5_s': -1.0,
}
状态转移概率:
# Transition table stored as a dict: "<state>_<action>" -> next state.
self.t = {
    '1_s': 6,
    '1_e': 2,
    '2_w': 1,
    '2_e': 3,
    '3_w': 2,
    '3_e': 4,
    '3_s': 7,
    '4_w': 3,
    '4_e': 5,
    '5_w': 4,
    '5_s': 8,
}
def step(self, action):
    """Advance the environment by one time step.

    Args:
        action: Action label; joined with the current state to build the
            ``"<state>_<action>"`` key used by both lookup tables.

    Returns:
        Tuple ``(next_state, reward, done, info)``; ``info`` is always an
        empty dict.
    """
    current = self.state

    # Already terminal: stay put, give no reward, report done.
    if current in self.terminate_states:
        return current, 0, True, {}

    # Key shared by the transition and reward tables.
    key = "%d_%s" % (current, action)

    # A pair missing from the transition table leaves the state unchanged.
    following = self.t.get(key, current)
    self.state = following

    done = following in self.terminate_states
    # A pair missing from the reward table yields zero reward.
    reward = self.rewards.get(key, 0.0)
    return following, reward, done, {}
reset
def reset(self):
    """Pick a random start state, store it in ``self.state`` and return it.

    The start is drawn from every entry of ``self.states`` except the last
    three.
    """
    candidate_count = len(self.states) - 3
    index = int(random.random() * candidate_count)
    self.state = self.states[index]
    return self.state
完整代码
import gym
from numpy import random
import time
class GridEnv(gym.Env):
    """Eight-state grid world with a robot, one coin and two traps.

    States 1-5 form the upper row; states 6, 7 and 8 sit below states 1, 3
    and 5 and are terminal. Moving south from state 3 gives reward +1.0,
    moving south from state 1 or 5 gives -1.0, and every other move gives 0.
    """

    # (start, end) pixel coordinates of every grid line segment.
    _GRID_LINES = (
        ((100, 300), (500, 300)),
        ((100, 200), (500, 200)),
        ((100, 300), (100, 100)),
        ((180, 300), (180, 100)),
        ((260, 300), (260, 100)),
        ((340, 300), (340, 100)),
        ((420, 300), (420, 100)),
        ((500, 300), (500, 100)),
        ((100, 100), (180, 100)),
        ((260, 100), (340, 100)),
        ((420, 100), (500, 100)),
    )

    def __init__(self):
        self.states = [1, 2, 3, 4, 5, 6, 7, 8]
        # Robot pixel position for each state (list index = state - 1).
        self.x = [140, 220, 300, 380, 460, 140, 300, 460]
        self.y = [250, 250, 250, 250, 250, 150, 150, 150]
        self.actions = ['n', 'e', 's', 'w']

        # Reward table keyed by "<state>_<action>"; missing keys mean 0.
        self.rewards = {
            '1_s': -1.0,
            '3_s': 1.0,
            '5_s': -1.0,
        }

        # Transition table: "<state>_<action>" -> next state. A missing key
        # means the move is not possible and the state is unchanged.
        self.t = {
            '1_s': 6,
            '1_e': 2,
            '2_w': 1,
            '2_e': 3,
            '3_w': 2,
            '3_e': 4,
            '3_s': 7,
            '4_w': 3,
            '4_e': 5,
            '5_w': 4,
            '5_s': 8,
        }

        self.terminate_states = [6, 7, 8]
        # No current state until reset() is called. Defining the attribute
        # here lets render() test it instead of raising AttributeError.
        self.state = None
        self.viewer = None

    def step(self, action):
        """Advance the environment by one time step.

        Args:
            action: One of ``self.actions``.

        Returns:
            Tuple ``(next_state, reward, done, info)``; ``info`` is always
            an empty dict.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        state = self.state
        if state is None:
            raise RuntimeError("reset() must be called before step()")

        # Already terminal: stay put, give no reward, report done.
        if state in self.terminate_states:
            return state, 0.0, True, {}

        # Key shared by the transition and reward tables.
        key = "%d_%s" % (state, action)

        next_state = self.t.get(key, state)
        self.state = next_state

        is_terminal = next_state in self.terminate_states
        r = self.rewards.get(key, 0.0)
        return next_state, r, is_terminal, {}

    def reset(self):
        """Start a new episode in a random non-terminal state and return it."""
        # Derived from terminate_states rather than a hard-coded count, so
        # the start set stays correct if the terminal states change.
        start_states = [s for s in self.states
                        if s not in self.terminate_states]
        index = int(random.random() * len(start_states))
        self.state = start_states[index]
        return self.state

    def close(self):
        """Close the render window, if one was opened."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None

    def render(self, mode='human'):
        """Draw the grid, coin, traps and the robot at its current state.

        Args:
            mode: ``'human'`` shows the window; ``'rgb_array'`` additionally
                asks the viewer to return the frame as an RGB array.

        Returns:
            ``None`` if there is no current state yet, otherwise whatever
            ``self.viewer.render`` returns for the requested mode.
        """
        # Imported here because this file does not import it at the top.
        from gym.envs.classic_control import rendering

        if self.viewer is None:
            self.viewer = rendering.Viewer(600, 400)

        # Nothing to place the robot at until reset() has been called.
        if self.state is None:
            return None

        # Grid
        for start, end in self._GRID_LINES:
            self.viewer.draw_line(start, end)

        # Coin
        coin = self.viewer.draw_circle(40, color=(1, 0.9, 0))
        coin.add_attr(rendering.Transform(translation=(300, 150)))

        # Traps
        for centre in ((140, 150), (460, 150)):
            trap = self.viewer.draw_circle(40, color=(0, 0, 0))
            trap.add_attr(rendering.Transform(translation=centre))

        # Robot: keep the circle itself in self.rebot. Chaining .add_attr()
        # onto draw_circle() would store that call's return value instead.
        self.rebottrans = rendering.Transform()
        self.rebot = self.viewer.draw_circle(30, color=(0.8, 0.6, 0.4))
        self.rebot.add_attr(self.rebottrans)
        self.rebottrans.set_translation(self.x[self.state - 1],
                                        self.y[self.state - 1])

        return self.viewer.render(return_rgb_array=(mode == 'rgb_array'))
# Manual test: take random actions until a terminal state is reached.
if __name__ == "__main__":
    env = GridEnv()
    try:
        env.reset()
        while True:
            action = env.actions[int(random.random() * len(env.actions))]
            next_state, r, is_terminal, info = env.step(action)
            env.render()
            if is_terminal:
                print("reward:", r)
                break
            # Slow the animation down so each move is visible.
            time.sleep(0.5)
    finally:
        # Release the window even if the loop is interrupted or fails.
        env.close()