This post implements a simple reinforcement-learning demo on top of a custom gym environment, adapted from a blog post on creating a custom environment with gym.
1. Dependency versions
gym == 0.21.0
stable-baselines3 == 1.6.2
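Assuming a pip-based setup, the pinned versions can be installed with (note that gym 0.21.0 may need an older pip/setuptools to build):

pip install gym==0.21.0 stable-baselines3==1.6.2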
2. Scenario description
Start: (0, 0)
Goal: (4, 4)
Action space: {0: up, 1: down, 2: left, 3: right}
State space: the agent's grid coordinates
Objective: walk from the start to the goal along the shortest path
Reward: +10 for reaching the goal, -1 for every other step
Termination: 1. the goal is reached; 2. the step count exceeds 200
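On this 5x5 grid the shortest path takes 8 steps (4 down, 4 right), so the best possible episode return is 7 x (-1) + 10 = 3.0; this is the value a well-trained policy should reach in Section 4.3.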
3. Building the gym environment
"""
@Author: Fhz
@Create Date: 2023/4/6 22:08
@File: gym_test.py
@Description:
@Modify Person Date:
"""
import gym
from gym import Env
from gym import spaces
import numpy as np
from copy import deepcopy
class PathPlanning(Env):
def __init__(self):
self.rows = 5
self.cols = 5
self.start = [0, 0]
self.goal = [4, 4]
self.count = 0
self.current_state = None
self.action_space = spaces.Discrete(4)
self.observation_space = spaces.Box(low=np.array([0, 0]), high=np.array([4, 4]))
def step(self, action):
self.count = self.count + 1
new_state = deepcopy(self.current_state)
if action == 0: # up
new_state[0] = max(new_state[0] - 1, 0)
elif action == 1: # down
new_state[0] = min(new_state[0] + 1, self.cols - 1)
elif action == 2: # left
new_state[1] = max(new_state[1] - 1, 0)
elif action == 3: # right
new_state[1] = min(new_state[1] + 1, self.rows - 1)
else:
raise Exception("Invalid action")
self.current_state = new_state
if self.current_state[1] == self.goal[1] and self.current_state[0] == self.goal[0]:
done = True
reward = 10.0
else:
done = False
reward = -1
if self.count > 200:
done = True
info = {}
return self.current_state, reward, done, info
def render(self):
pass
def reset(self):
self.count = 0
self.current_state = self.start
return self.current_state
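Before training, it is worth validating that the class satisfies the gym interface. A minimal sanity check, using the env checker that ships with Stable-Baselines3 (any warnings it prints here are stylistic rather than fatal):

from stable_baselines3.common.env_checker import check_env

env = PathPlanning()
check_env(env, warn=True)  # raises an error if the env violates the gym API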
4. Reinforcement-learning optimization with Stable-Baselines3
4.1 Parameter settings
Network architecture: "mlp", 64x64
Learning rate: 5e-4
Batch size: 32
Training steps: 5e4
Discount factor (gamma): 0.8
"""
@Author: Fhz
@Create Date: 2023/4/6 22:08
@File: gym_test.py
@Description:
@Modify Person Date:
"""
import gym
from gym import Env
from gym import spaces
import numpy as np
from copy import deepcopy
from stable_baselines3 import PPO
class PathPlanning(Env):
def __init__(self):
self.rows = 5
self.cols = 5
self.start = [0, 0]
self.goal = [4, 4]
self.count = 0
self.current_state = None
self.action_space = spaces.Discrete(4)
self.observation_space = spaces.Box(low=np.array([0, 0]), high=np.array([4, 4]))
def step(self, action):
self.count = self.count + 1
new_state = deepcopy(self.current_state)
if action == 0: # up
new_state[0] = max(new_state[0] - 1, 0)
elif action == 1: # down
new_state[0] = min(new_state[0] + 1, self.cols - 1)
elif action == 2: # left
new_state[1] = max(new_state[1] - 1, 0)
elif action == 3: # right
new_state[1] = min(new_state[1] + 1, self.rows - 1)
else:
raise Exception("Invalid action")
self.current_state = new_state
if self.current_state[1] == self.goal[1] and self.current_state[0] == self.goal[0]:
done = True
reward = 10.0
else:
done = False
reward = -1
if self.count > 200:
done = True
info = {}
return self.current_state, reward, done, info
def render(self):
pass
def reset(self):
self.count = 0
self.current_state = self.start
return self.current_state
if __name__ == '__main__':
    env = PathPlanning()

    model = PPO('MlpPolicy', env,
                policy_kwargs=dict(net_arch=[64, 64]),  # two hidden layers of 64 units
                learning_rate=5e-4,
                batch_size=32,
                gamma=0.8,
                verbose=1,
                tensorboard_log="PPO_define/")

    model.learn(int(5e4))             # train for 50,000 environment steps
    model.save("PPO_define/PPOmodel")
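After training, a quick way to confirm convergence is Stable-Baselines3's evaluate_policy helper, which should report a mean return near the 3.0 optimum derived in Section 2. A minimal sketch (it warns if the env is not wrapped in a Monitor, which is harmless here):

from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, deterministic=True)
print("mean reward: {:.2f} +/- {:.2f}".format(mean_reward, std_reward))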
4.2 Training results
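With tensorboard_log="PPO_define/" set above, the learning curves can be inspected with the standard TensorBoard invocation:

tensorboard --logdir PPO_define/

The mean episode reward should climb toward the 3.0 optimum as training progresses.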
4.3 Testing the trained model
if __name__ == '__main__':
    env = PathPlanning()

    ACTIONS_ALL = {
        0: 'Up',
        1: 'Down',
        2: 'Left',
        3: 'Right',
    }

    # load the trained model
    model = PPO.load("PPO_define/PPOmodel", env=env)

    episodes = 10
    for eq in range(episodes):
        obs = env.reset()
        done = False
        rewards = 0

        while not done:
            # action = env.action_space.sample()  # random-policy baseline
            action, _state = model.predict(obs, deterministic=True)
            action = action.item()  # convert the numpy scalar to a plain int
            print("The action is: {}".format(ACTIONS_ALL[action]))
            obs, reward, done, info = env.step(action)
            env.render()
            rewards += reward
        print(rewards)
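The render() method above is a no-op; for a quick visual check, a minimal text-based renderer could be dropped into the class instead. A sketch, assuming the attribute names from Section 3:

    def render(self, mode="human"):
        # print the grid with A = agent, G = goal, . = empty cell
        for r in range(self.rows):
            row = ""
            for c in range(self.cols):
                if [r, c] == list(self.current_state):
                    row += "A "
                elif [r, c] == self.goal:
                    row += "G "
                else:
                    row += ". "
            print(row)
        print()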
Output:
The action is: Down
The action is: Down
The action is: Right
The action is: Right
The action is: Down
The action is: Right
The action is: Down
The action is: Right
3.0
(the remaining nine episodes produce the same eight actions and the same return of 3.0)
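The deterministic policy reproduces an 8-step shortest path (four Down, four Right) in every episode, and the return of 3.0 matches the optimum computed in Section 2, so PPO has converged to an optimal policy for this grid.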