A First Look at Reinforcement Learning with PARL

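1. What Is Reinforcement Learning

  • Reinforcement learning (RL) is an area of machine learning concerned with how to act in an environment so as to maximize the expected reward.
  • Core idea: an agent learns inside an environment. Based on the environment's state (or the observation it receives), the agent takes an action, and it uses the environment's feedback, the reward, to guide it toward better actions.

Note: what the agent gets from the environment is sometimes called the state and sometimes the observation. Strictly speaking, the state is the global state and an observation is a local view of it. The difference matters in multi-agent environments, but the environments we meet at the start are not that complex, so for now the two can be treated as the same thing.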
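The cycle of observing, acting, and receiving a reward can be sketched in a few lines of plain Python. The `CorridorEnv` class and the `run_random_episode` helper are hypothetical names made up for this illustration (they are not part of Gym or PARL), and the reward of -1 per step is an arbitrary choice.

```python
import random


class CorridorEnv:
    """Toy 1-D corridor: start at cell 0, finish at cell `length - 1`.

    Hypothetical environment for illustration; not part of Gym or PARL.
    """

    def __init__(self, length=5):
        self.length = length
        self.pos = 0

    def reset(self):
        self.pos = 0
        return self.pos  # the observation is simply the current cell index

    def step(self, action):
        # action 0 = move left, action 1 = move right (clamped to the corridor)
        move = 1 if action == 1 else -1
        self.pos = min(max(self.pos + move, 0), self.length - 1)
        done = self.pos == self.length - 1
        reward = 0.0 if done else -1.0  # every non-final step costs 1
        return self.pos, reward, done


def run_random_episode(env, rng, max_steps=1000):
    """Run one episode with a random agent and return the total reward."""
    obs = env.reset()
    total_reward = 0.0
    for _ in range(max_steps):
        action = rng.choice([0, 1])           # the agent picks an action
        obs, reward, done = env.step(action)  # the environment responds
        total_reward += reward
        if done:
            break
    return total_reward


print(run_random_episode(CorridorEnv(), random.Random(0)))
```

A learning agent would replace `rng.choice` with a rule that uses `obs` and the rewards seen so far; the loop itself stays the same.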

Figure: the reinforcement learning model

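2. What It Can Do

  • Games (Mario, Atari, AlphaGo, StarCraft, etc.)
  • Robot control (robotic arms, robots, autonomous driving, quadcopters, etc.)
  • User interaction (recommendation, advertising, NLP, etc.)
  • Transportation (congestion management, etc.)
  • Resource scheduling (logistics, bandwidth, power, etc.)
  • Finance (portfolio management, stock trading, etc.)
  • Others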

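3. How Reinforcement Learning Differs from Supervised Learning

  • Reinforcement learning, supervised learning, and unsupervised learning are three distinct areas of machine learning, and each overlaps with deep learning.
  • Supervised learning looks for a mapping from inputs to outputs, as in classification and regression.
  • Unsupervised learning mainly looks for hidden relationships in data, as in clustering.
  • Reinforcement learning has to learn, and search for the best decision policy, through interaction with an environment.
  • Supervised learning handles recognition problems; reinforcement learning handles decision-making problems.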

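4. How Reinforcement Learning Solves Problems

  • Reinforcement learning explores by trial and error, learns from both its successes and its mistakes, and keeps refining its policy so that it gets better feedback from the environment.
  • There are two families of approaches: value-based and policy-based.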
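The difference between the two families shows up in how an action is chosen. The sketch below is a simplified illustration with made-up numbers: a value-based method keeps an estimate of each action's value and picks the best one, while a policy-based method keeps a probability distribution over actions and samples from it. The function names are hypothetical.

```python
import math
import random


def greedy_action(q_values):
    """Value-based: pick the action with the highest estimated value."""
    return max(range(len(q_values)), key=lambda a: q_values[a])


def softmax(preferences):
    """Turn arbitrary scores into a probability distribution."""
    m = max(preferences)
    exps = [math.exp(p - m) for p in preferences]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(probs, rng):
    """Policy-based: draw an action from the policy's distribution."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


q_values = [0.1, 0.7, 0.2]        # made-up Q(s, a) for one state
print(greedy_action(q_values))

probs = softmax([1.0, 2.0, 0.5])  # made-up action preferences
print(sample_action(probs, random.Random(0)))
```

In practice value-based methods also need some exploration (for example epsilon-greedy), whereas sampling from a policy explores by construction.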

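5. Algorithms and Environments

  • Classic algorithms: Q-learning, Sarsa, DQN, Policy Gradient, A3C, DDPG, PPO
  • Environment types: discrete control (the set of output actions is countable) and continuous control (the output action values are not countable)
  • Gym, the classic RL environment library, standardizes the environment interface as reset() to reset the environment, step() to interact with it, and render() to draw it.
  • PARL, an RL framework library, abstracts an RL program into three layers (Model, Algorithm, Agent), which makes RL algorithms easier and more flexible to implement and debug.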
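The three-layer split can be illustrated without PARL. The sketch below is a plain-Python analogy built around tabular Q-learning; `TableModel`, `QLearning`, and `TableAgent` are hypothetical names, and none of this is PARL's actual API. It only shows the division of labor: the model holds parameters, the algorithm defines the update rule, and the agent handles interaction and feeds data to the algorithm.

```python
import random


class TableModel:
    """Model layer: holds the parameters (here a Q-table) and the forward pass."""

    def __init__(self, n_states, n_actions):
        self.q = [[0.0] * n_actions for _ in range(n_states)]

    def value(self, state):
        return self.q[state]


class QLearning:
    """Algorithm layer: defines how the model is updated from experience."""

    def __init__(self, model, lr=0.1, gamma=0.9):
        self.model = model
        self.lr = lr
        self.gamma = gamma

    def learn(self, state, action, reward, next_state, done):
        target = reward
        if not done:
            target += self.gamma * max(self.model.value(next_state))
        td_error = target - self.model.q[state][action]
        self.model.q[state][action] += self.lr * td_error


class TableAgent:
    """Agent layer: chooses actions and passes experience to the algorithm."""

    def __init__(self, algorithm, n_actions, epsilon=0.1, seed=0):
        self.alg = algorithm
        self.n_actions = n_actions
        self.epsilon = epsilon
        self.rng = random.Random(seed)

    def predict(self, state):
        q = self.alg.model.value(state)
        return max(range(self.n_actions), key=lambda a: q[a])

    def sample(self, state):
        # epsilon-greedy: explore with probability epsilon, else act greedily
        if self.rng.random() < self.epsilon:
            return self.rng.randrange(self.n_actions)
        return self.predict(state)

    def learn(self, state, action, reward, next_state, done):
        self.alg.learn(state, action, reward, next_state, done)


model = TableModel(n_states=2, n_actions=2)
agent = TableAgent(QLearning(model, lr=0.5), n_actions=2)
agent.learn(state=0, action=1, reward=1.0, next_state=1, done=True)
print(model.q[0])  # → [0.0, 0.5]
```

Swapping the update rule (say, Sarsa instead of Q-learning) only touches the algorithm class, which is the kind of flexibility the layering is meant to provide.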


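6. A Small Example

Start with the Git repository: https://github.com/PaddlePaddle/PARL.git

It contains implementations of various algorithms along with some small examples.

In this example a little turtle has to walk from the bottom-left corner to the yellow cell in the bottom-right corner.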

# -*- coding: utf-8 -*-

import gym
import turtle
import numpy as np

# turtle tutorial : https://docs.python.org/3.3/library/turtle.html


def GridWorld(gridmap=None, is_slippery=False):
    if gridmap is None:
        gridmap = ['SFFF', 'FHFH', 'FFFH', 'HFFG']
    env = gym.make("FrozenLake-v0", desc=gridmap, is_slippery=is_slippery)
    env = FrozenLakeWapper(env)
    return env


class FrozenLakeWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.max_y = env.desc.shape[0]
        self.max_x = env.desc.shape[1]
        self.t = None
        self.unit = 50

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for _ in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t is None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for i in range(self.desc.shape[0]):
                for j in range(self.desc.shape[1]):
                    x = j
                    y = self.max_y - 1 - i
                    if self.desc[i][j] == b'S':  # Start
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'F':  # Frozen ice
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'G':  # Goal
                        self.draw_box(x, y, 'yellow')
                    elif self.desc[i][j] == b'H':  # Hole
                        self.draw_box(x, y, 'black')
                    else:
                        self.draw_box(x, y, 'white')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


class CliffWalkingWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.t = None
        self.unit = 50
        self.max_x = 12
        self.max_y = 4

    def draw_x_line(self, y, x0, x1, color='gray'):
        assert x1 > x0
        self.t.color(color)
        self.t.setheading(0)
        self.t.up()
        self.t.goto(x0, y)
        self.t.down()
        self.t.forward(x1 - x0)

    def draw_y_line(self, x, y0, y1, color='gray'):
        assert y1 > y0
        self.t.color(color)
        self.t.setheading(90)
        self.t.up()
        self.t.goto(x, y0)
        self.t.down()
        self.t.forward(y1 - y0)

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for i in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t is None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for _ in range(2):
                self.t.forward(self.max_x * self.unit)
                self.t.left(90)
                self.t.forward(self.max_y * self.unit)
                self.t.left(90)
            for i in range(1, self.max_y):
                self.draw_x_line(
                    y=i * self.unit, x0=0, x1=self.max_x * self.unit)
            for i in range(1, self.max_x):
                self.draw_y_line(
                    x=i * self.unit, y0=0, y1=self.max_y * self.unit)

            for i in range(1, self.max_x - 1):
                self.draw_box(i, 0, 'black')
            self.draw_box(self.max_x - 1, 0, 'yellow')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


if __name__ == '__main__':

    env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left
    env = CliffWalkingWapper(env)

    env.reset()
    # for step in range(100):
    #     action = np.random.randint(0, 4)
    #     obs, reward, done, info = env.step(action)
    #     print('step {}: action {}, obs {}, reward {}, done {}, info {}'.format(
    #             step, action, obs, reward, done, info))
    #     env.render()  # render one frame

    # Take random actions until the episode ends.
    while True:
        action = np.random.randint(0, 4)
        obs, reward, done, info = env.step(action)
        print('action {}, obs {}, reward {}, done {}, info {}'.format(
            action, obs, reward, done, info))
        env.render()  # render one frame
        if done:
            break

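Everything outside the main block (the class definitions) can be ignored for now. First get to know the key pieces:

env = gym.make("CliffWalking-v0") creates a cliff-walking environment.

env = CliffWalkingWapper(env) sets up the visualization. This line is optional; with it, the display looks nicer.

env.reset() puts the environment into its initial state.

action = np.random.randint(0, 4)
obs, reward, done, info = env.step(action)

These two lines pick one of the four directions at random and move one step in it.

env.render() redraws the picture after each move (it renders one frame).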
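The obs returned by env.step() is a single integer, not a pair of coordinates. The helper below is a small sketch of how that integer maps to a grid cell, using the same arithmetic as CliffWalkingWapper.render() above; the function name `state_to_xy` is made up for illustration. In gym's CliffWalking-v0 the start state is 36 and the goal state is 47, on a grid 12 cells wide and 4 cells high.

```python
def state_to_xy(state, max_x=12, max_y=4):
    """Map a CliffWalking state index to (x, y) with the origin at bottom-left.

    States are numbered row by row starting from the top-left cell, so the
    row index has to be flipped to get a y coordinate that grows upward.
    """
    x = state % max_x
    y = max_y - 1 - state // max_x
    return x, y


print(state_to_xy(36))  # start cell → (0, 0)
print(state_to_xy(47))  # goal cell  → (11, 0)
```

This is why an episode that ends with obs 47 has reached the yellow cell in the bottom-right corner.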

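7. The CartPole Balancing Problem

CartPole works as shown in the figure: the goal is to keep a pole pointing straight up. Gravity keeps tilting the pole, and once it tilts past a certain angle it falls over, so the cart under it has to be moved left or right to keep the pole from falling.

This time the main goal is to get the code running. Setting up the environment can cause some trouble before the code runs. We recommend uninstalling the libraries involved and reinstalling them; the run recorded in this post ended up with parl 1.3.2 and paddlepaddle 1.8.2, and newer releases may no longer support the paddle.fluid API used here.

Git repository: https://github.com/PaddlePaddle/PARL.git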

Building the Agent
import numpy as np
import paddle.fluid as fluid
import parl
from parl import layers


class CartpoleAgent(parl.Agent):
    def __init__(self, algorithm, obs_dim, act_dim):
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        super(CartpoleAgent, self).__init__(algorithm)

    def build_program(self):
        self.pred_program = fluid.Program()
        self.learn_program = fluid.Program()

        with fluid.program_guard(self.pred_program):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            self.act_prob = self.alg.predict(obs)

        with fluid.program_guard(self.learn_program):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            act = layers.data(name='act', shape=[1], dtype='int64')
            reward = layers.data(name='reward', shape=[], dtype='float32')
            self.cost = self.alg.learn(obs, act, reward)

    def sample(self, obs):
        obs = np.expand_dims(obs, axis=0)
        act_prob = self.fluid_executor.run(
            self.pred_program,
            feed={'obs': obs.astype('float32')},
            fetch_list=[self.act_prob])[0]
        act_prob = np.squeeze(act_prob, axis=0)
        act = np.random.choice(range(self.act_dim), p=act_prob)
        return act

    def predict(self, obs):
        obs = np.expand_dims(obs, axis=0)
        act_prob = self.fluid_executor.run(
            self.pred_program,
            feed={'obs': obs.astype('float32')},
            fetch_list=[self.act_prob])[0]
        act_prob = np.squeeze(act_prob, axis=0)
        act = np.argmax(act_prob)
        return act

    def learn(self, obs, act, reward):
        act = np.expand_dims(act, axis=-1)
        feed = {
            'obs': obs.astype('float32'),
            'act': act.astype('int64'),
            'reward': reward.astype('float32')
        }
        cost = self.fluid_executor.run(
            self.learn_program, feed=feed, fetch_list=[self.cost])[0]
        return cost
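The agent's learn() feeds observations, actions, and rewards to the algorithm and gets back a cost. As a rough guide to what such a cost looks like, here is a sketch of the textbook policy-gradient (REINFORCE) loss in plain Python. It is not a copy of PARL's internals; the function name and the sample numbers are made up for illustration.

```python
import math


def policy_gradient_loss(act_probs, actions, returns):
    """Mean of -log pi(a_t | s_t) * G_t over the steps of one episode.

    act_probs: per-step lists of action probabilities from the policy
    actions:   indices of the actions actually taken
    returns:   the return (e.g. reward-to-go) credited to each step
    """
    total = 0.0
    for probs, act, ret in zip(act_probs, actions, returns):
        total += -math.log(probs[act]) * ret
    return total / len(actions)


# Two steps of a made-up episode with two possible actions.
act_probs = [[0.5, 0.5], [0.25, 0.75]]
actions = [0, 1]
returns = [2.0, 1.0]
loss = policy_gradient_loss(act_probs, actions, returns)
print(loss)
```

Minimizing this quantity raises the probability of actions that were followed by high returns, which is why sample() draws actions at random during training while predict() takes the most probable action at test time.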

Building the Model
import parl
from parl import layers


class CartpoleModel(parl.Model):
    def __init__(self, act_dim):
        # Hidden layer is 10x the number of actions (20 units for CartPole).
        hid1_size = act_dim * 10

        self.fc1 = layers.fc(size=hid1_size, act='tanh')
        self.fc2 = layers.fc(size=act_dim, act='softmax')

    def forward(self, obs):
        out = self.fc1(obs)
        out = self.fc2(out)
        return out
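To make the forward pass concrete, the sketch below redoes the same computation (a tanh hidden layer followed by a softmax output layer) in plain Python with randomly initialized weights. The helper names and the initialization range are assumptions for illustration only; PARL's layers handle parameter creation internally.

```python
import math
import random


def fc(inputs, weights, biases):
    """Fully connected layer: one output per row of `weights`."""
    return [sum(w * x for w, x in zip(row, inputs)) + b
            for row, b in zip(weights, biases)]


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def forward(obs, params):
    """tanh hidden layer followed by a softmax output layer."""
    w1, b1, w2, b2 = params
    hidden = [math.tanh(h) for h in fc(obs, w1, b1)]
    return softmax(fc(hidden, w2, b2))


def init_params(obs_dim, act_dim, rng):
    """Small random weights; the hidden size is act_dim * 10 as in the model."""
    hid = act_dim * 10
    w1 = [[rng.uniform(-0.1, 0.1) for _ in range(obs_dim)] for _ in range(hid)]
    w2 = [[rng.uniform(-0.1, 0.1) for _ in range(hid)] for _ in range(act_dim)]
    return w1, [0.0] * hid, w2, [0.0] * act_dim


params = init_params(obs_dim=4, act_dim=2, rng=random.Random(0))
probs = forward([0.01, -0.02, 0.03, 0.04], params)
print(probs, sum(probs))
```

The output is one probability per action, which is the shape the agent's sample() and predict() methods expect.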

Main Program
import gym
import numpy as np
import parl
import os.path
from cartpole_agent import CartpoleAgent
from cartpole_model import CartpoleModel
from parl.utils import logger

OBS_DIM = 4
ACT_DIM = 2
LEARNING_RATE = 1e-3


def run_episode(env, agent, train_or_test='train'):
    obs_list, action_list, reward_list = [], [], []
    obs = env.reset()
    while True:
        obs_list.append(obs)
        if train_or_test == 'train':
            action = agent.sample(obs)
        else:
            action = agent.predict(obs)
        action_list.append(action)

        obs, reward, done, info = env.step(action)
        reward_list.append(reward)

        if done:
            break
    return obs_list, action_list, reward_list


def calc_reward_to_go(reward_list):
    for i in range(len(reward_list) - 2, -1, -1):
        reward_list[i] += reward_list[i + 1]
    return np.array(reward_list)


def main():
    env = gym.make("CartPole-v0")
    model = CartpoleModel(act_dim=ACT_DIM)
    alg = parl.algorithms.PolicyGradient(model, lr=LEARNING_RATE)
    agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM)

    # if the file already exists, restore parameters from it
    if os.path.exists('./model.ckpt'):
        agent.restore('./model.ckpt')

    for i in range(1000):
        obs_list, action_list, reward_list = run_episode(env, agent)
        if i % 10 == 0:
            logger.info("Episode {}, Reward Sum {}.".format(
                i, sum(reward_list)))

        batch_obs = np.array(obs_list)
        batch_action = np.array(action_list)
        batch_reward = calc_reward_to_go(reward_list)

        agent.learn(batch_obs, batch_action, batch_reward)
        if (i + 1) % 100 == 0:
            _, _, reward_list = run_episode(env, agent, train_or_test='test')
            total_reward = np.sum(reward_list)
            logger.info('Test reward: {}'.format(total_reward))

    # save the parameters to ./model.ckpt
    agent.save('./model.ckpt')


if __name__ == '__main__':
    main()
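calc_reward_to_go above replaces each reward with the sum of that reward and all later ones, so that every action is credited only with what happened after it. The sketch below shows the same idea on a tiny input. The `gamma` discount parameter is an addition that the original function does not have, and this version returns a new list instead of modifying its argument.

```python
def reward_to_go(rewards, gamma=1.0):
    """Return G_t = r_t + gamma * G_{t+1} for every step of an episode."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


print(reward_to_go([1.0, 1.0, 1.0]))             # → [3.0, 2.0, 1.0]
print(reward_to_go([1.0, 1.0, 1.0], gamma=0.5))  # → [1.75, 1.5, 1.0]
```

With gamma=1.0 this matches the undiscounted sum used in the main program; CartPole gives a reward of 1 per step, so early steps of a long episode get the largest returns.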

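With training, the score eventually stabilizes at 200, the maximum for CartPole-v0, whose episodes are capped at 200 steps.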

!pip install gym
Requirement already satisfied: gym in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (0.12.1)
Output of the last step of one run of the CliffWalking script above:
action 2, obs 47, reward -1, done True, info {'prob': 1.0}
# !pip uninstall -y parl paddlepaddle
!pip install parl paddlepaddle -i https://mirror.baidu.com/pypi/simple
Successfully installed flask-cors-3.0.8 paddlepaddle-1.8.2 parl-1.3.2 psutil-5.7.0

import os.path

import gym
import numpy as np
import paddle.fluid as fluid
import parl
from parl import layers
from parl.utils import logger

# CartpoleAgent and CartpoleModel are defined below in this same file,
# so they are not imported from separate cartpole_agent / cartpole_model modules.



class CartpoleAgent(parl.Agent):
    def __init__(self, algorithm, obs_dim, act_dim):
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        super(CartpoleAgent, self).__init__(algorithm)

    def build_program(self):
        self.pred_program = fluid.Program()
        self.learn_program = fluid.Program()

        with fluid.program_guard(self.pred_program):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            self.act_prob = self.alg.predict(obs)

        with fluid.program_guard(self.learn_program):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            act = layers.data(name='act', shape=[1], dtype='int64')
            reward = layers.data(name='reward', shape=[], dtype='float32')
            self.cost = self.alg.learn(obs, act, reward)

    def sample(self, obs):
        # Training: draw an action at random according to the policy's probabilities.
        obs = np.expand_dims(obs, axis=0)
        act_prob = self.fluid_executor.run(
            self.pred_program,
            feed={'obs': obs.astype('float32')},
            fetch_list=[self.act_prob])[0]
        act_prob = np.squeeze(act_prob, axis=0)
        act = np.random.choice(range(self.act_dim), p=act_prob)
        return act

    def predict(self, obs):
        # Evaluation: greedily pick the action with the highest probability.
        obs = np.expand_dims(obs, axis=0)
        act_prob = self.fluid_executor.run(
            self.pred_program,
            feed={'obs': obs.astype('float32')},
            fetch_list=[self.act_prob])[0]
        act_prob = np.squeeze(act_prob, axis=0)
        act = np.argmax(act_prob)
        return act

    def learn(self, obs, act, reward):
        act = np.expand_dims(act, axis=-1)
        feed = {
            'obs': obs.astype('float32'),
            'act': act.astype('int64'),
            'reward': reward.astype('float32')
        }
        cost = self.fluid_executor.run(
            self.learn_program, feed=feed, fetch_list=[self.cost])[0]
        return cost




class CartpoleModel(parl.Model):
    def __init__(self, act_dim):
        # Hidden layer width scales with the number of actions.
        hid1_size = act_dim * 10

        self.fc1 = layers.fc(size=hid1_size, act='tanh')
        self.fc2 = layers.fc(size=act_dim, act='softmax')

    def forward(self, obs):
        out = self.fc1(obs)
        out = self.fc2(out)
        return out



OBS_DIM = 4
ACT_DIM = 2
LEARNING_RATE = 1e-3


def run_episode(env, agent, train_or_test='train'):
    obs_list, action_list, reward_list = [], [], []
    obs = env.reset()
    while True:
        obs_list.append(obs)
        if train_or_test == 'train':
            action = agent.sample(obs)
        else:
            action = agent.predict(obs)
        action_list.append(action)

        obs, reward, done, info = env.step(action)
        reward_list.append(reward)

        if done:
            break
    return obs_list, action_list, reward_list


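def calc_reward_to_go(reward_list):
    """Return the undiscounted reward-to-go G_t = r_t + r_{t+1} + ... for each step.

    Note: reward_list is modified in place.
    """
    for i in range(len(reward_list) - 2, -1, -1):
        reward_list[i] += reward_list[i + 1]
    return np.array(reward_list)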


def main():
    env = gym.make("CartPole-v0")
    model = CartpoleModel(act_dim=ACT_DIM)
    alg = parl.algorithms.PolicyGradient(model, lr=LEARNING_RATE)
    agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM)

    # if the file already exists, restore parameters from it
    if os.path.exists('./model.ckpt'):
        agent.restore('./model.ckpt')

    for i in range(1000):
        obs_list, action_list, reward_list = run_episode(env, agent)
        if i % 10 == 0:
            logger.info("Episode {}, Reward Sum {}.".format(
                i, sum(reward_list)))

        batch_obs = np.array(obs_list)
        batch_action = np.array(action_list)
        batch_reward = calc_reward_to_go(reward_list)

        agent.learn(batch_obs, batch_action, batch_reward)
        if (i + 1) % 100 == 0:
            _, _, reward_list = run_episode(env, agent, train_or_test='test')
            total_reward = np.sum(reward_list)
            logger.info('Test reward: {}'.format(total_reward))

    # save the parameters to ./model.ckpt
    agent.save('./model.ckpt')


if __name__ == '__main__':
    main()
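
The `calc_reward_to_go` step above turns per-step rewards into the return collected from each step onward. The following is a minimal standalone sketch of the same idea, not part of the PARL example: the helper name `reward_to_go` and the optional discount factor `gamma` are assumptions added for illustration (the code above uses no discount, which corresponds to `gamma=1.0`), and this version leaves its input list unmodified.

```python
import numpy as np


def reward_to_go(rewards, gamma=1.0):
    """Return G_t = r_t + gamma * G_{t+1} for every step t.

    Hypothetical helper for illustration; gamma=1.0 reproduces the
    undiscounted sum used by calc_reward_to_go in the example above.
    The input list is not modified.
    """
    returns = np.zeros(len(rewards), dtype=np.float32)
    running = 0.0
    # Walk backwards so each step can reuse the return of the next step.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


rewards = [1.0, 1.0, 1.0, 1.0]  # CartPole gives +1 for every step survived
print(reward_to_go(rewards).tolist())       # [4.0, 3.0, 2.0, 1.0]
print(reward_to_go(rewards, 0.5).tolist())  # [1.875, 1.75, 1.5, 1.0]
```

Earlier actions receive larger returns because they are credited with every reward that follows them, which is the signal the policy gradient update uses to weight each sampled action.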

To run the code, visit: https://aistudio.baidu.com/aistudio/projectdetail/625901?shared=1

Likes, favorites, and comments are welcome!
