Deeplearning4j 实战 (9):强化学习 -- Cartpole任务的训练和效果测试

Eclipse Deeplearning4j GitChat课程https://gitbook.cn/gitchat/column/5bfb6741ae0e5f436e35cd9f
Eclipse Deeplearning4j 系列博客https://blog.csdn.net/wangongxi
Eclipse Deeplearning4j Githubhttps://github.com/eclipse/deeplearning4j

在之前的博客中,我用Deeplearning4j构建深度神经网络来解决监督、无监督的机器学习问题。但除了这两类问题外,强化学习也是机器学习中一个重要的分支,并且Deeplearning4j的子项目--Rl4j提供了对部分强化学习算法的支持。这里,就以强化学习中的经典任务--Cartpole问题作为学习Rl4j的入门例子,讲解从环境搭建、模型训练再到最后的效果评估的结果。

Cartpole描述的问题可以认为是:在一辆小车上竖立一根杆子,然后给小车一个推或者拉的力,使得杆子尽量保持平衡不滑倒。更详细的描述可参见openai官网上关于Cartpole问题的解释:https://gym.openai.com/envs/CartPole-v0

接着给出强化学习的一些概念:environment,action,reward

environment:描述强化学习问题中的外部环境,比如:Cartpole问题中杆子的角度,小车的位置、速度等。

action:在不同外部环境条件下采取的动作,比如:Cartpole问题中对于小车施加推或者拉的力。action可以是离散的集合,也可以是连续的。

reward:对于agent/network作出的action后获取的回报/评价。比如:Cartpole问题中如果施加的力可以继续让杆子保持平衡,那reward就+1。

在描述reward这个概念时,提到了agent这个概念,在实际应用中,agent可以用神经网络来实现。

对于强化学习训练后的agent来说,学习到的是如何在变化中的environment和reward选择action的能力。通常有两种学习策略可以选择:Policy-Based和Value-Based。 Policy-Based直接学习action,通过Policy Gradient来更新模型参数,而相对的,Value-Based是最优化action所带来的reward(action-value function,Q-function)来间接选取action。一般认为如果action是离散的,那么Value-Based会优于Policy-Based,而连续的action则相反。在这里主要讨论Value-Based的学习策略,或者更具体的说Q-learning的问题。对于Policy-Based还有Model-Based不做讨论。

Q-learning的概念早在20多年前就已经提出,再与近年来流行的深度神经网络结合产生了DQN的概念。Q-learning的目标是最大化Q值从而学习到选取action的策略。Q-leaning学习的策略公式:

Q(st,at)←Q(st,at)+α[rt+1+λmaxaQ(st+1,a)−Q(st,at)]

对于这里主要讨论的Catpole问题,我们也采用Q-learning来实现。

可以看到,与监督学习相比,强化学习多了action,environment等概念。虽然可以将reward类比成监督学习中的label(或者反过来,label也可以认为是强化学习中最终的reward),但通过action与environment不断的交互甚至改变environment这一特点,是监督学习中所没有的。在构建应用的时候,监督学习的学习的目标:label,灌入的数据都是一个定值。比如,图像的分类的问题,在用CNN训练的时候,图片本身不发生变化,label也不会发生变化,唯一变化的是神经网络中的权重值。但强化学习在训练的时候,除了神经网络中的权重会发生变化(如果用NN建模的话),environment、reward等都会发生动态的变化。这样构建合适正确的训练数据会比较麻烦,容易出错,所以对于CartPole问题,我们可以采用openAI提供的强化学习开发环境gym来训练/测试agent。

gym的官方地址:https://gym.openai.com/

gym提供了棋类、视频游戏等强化学习问题的学习/测试/算法效果比较的环境。这里要处理的Cartpole问题,gym也提供了环境的支持。但是,除了python,gym对其他语言的支持不是很友好,所以为了可以获取gym中的数据,RL4j提供了对gym-http-api(https://github.com/openai/gym-http-api)调用的包装类。gym-http-api是为了方便除python外的其他语言也可以使用gym环境数据的一个REST接口。简单来说,对于像RL4j这样以Java实现的强化学习算法库可以通过gym-http-api获取gym提供的数据。

gym的REST接口的安装可以参见之前给出的github地址,里面有详细的描述。下面先给出gym-http-api的安装和启动过程的截图:

下面就结合上面说的内容,给出RL4j的Catpole实现逻辑

1. 定义Q-learning的参数以及神经网络结构,两者共同决定DQN的属性

2. 定义读取gym数据的包装类对象

3. 训练DQN并保存模型

4. 加载保存的模型并测试

这里先贴下需要的Maven依赖以及代码版本

 

  <properties>
	  <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> 
	  <nd4j.version>0.8.0</nd4j.version>  
	  <dl4j.version>0.8.0</dl4j.version>  
	  <datavec.version>0.8.0</datavec.version>  
	  <rl4j.version>0.8.0</rl4j.version>
	  <scala.binary.version>2.10</scala.binary.version>  
  </properties>
  <dependencies>
	<dependency>  
		<groupId>org.nd4j</groupId>  
		<artifactId>nd4j-native</artifactId>   
		<version>${nd4j.version}</version>  
	</dependency>  
        <dependency>  
		<groupId>org.deeplearning4j</groupId>  
		<artifactId>deeplearning4j-core</artifactId>  
		<version>${dl4j.version}</version>  
	</dependency>  
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-core</artifactId>
            <version>${rl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-gym</artifactId>
            <version>${rl4j.version}</version>
        </dependency>

  </dependencies>

 

 

 

 

 

第一部分的代码如下:

 

    public static QLearning.QLConfiguration CARTPOLE_QL =
            new QLearning.QLConfiguration(
                    123,    //Random seed
                    200,    //Max step By epoch
                    150000, //Max step
                    150000, //Max size of experience replay
                    32,     //size of batches
                    500,    //target update (hard)
                    10,     //num step noop warmup
                    0.01,   //reward scaling
                    0.99,   //gamma
                    1.0,    //td-error clipping
                    0.1f,   //min epsilon
                    1000,   //num step for eps greedy anneal
                    true    //double DQN
            );

    public static DQNFactoryStdDense.Configuration CARTPOLE_NET = DQNFactoryStdDense.Configuration.builder()            												.l2(0.001)            																        .learningRate(0.0005)
       								.numHiddenNodes(16)
           							.numLayer(3)
            							.build();

 

第一部分中定义Q-learning的参数,包括每一轮的训练的可采取的action的步数,最大步数以及存储过往action的最大步数等。除此以外,DQNFactoryStdDense用来定义基于MLP的DQN网络结构,包括网络深度等常见参数。这里的代码定义的是一个三层(只有一层隐藏层)的全连接神经网络。

接下来,定义两个方法分别用于训练和测试。catpole方法用于训练DQN,而loadCartpole则用于测试。

训练:

 

    public static void cartPole() {

        //record the training data in rl4j-data in a new folder (save)
        DataManager manager = new DataManager(true);

        //define the mdp from gym (name, render)
        GymEnv<Box, Integer, DiscreteSpace> mdp = null;
        try {
            mdp = new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", false, false);
        } catch (RuntimeException e){
            System.out.print("To run this example, download and start the gym-http-api repo found at https://github.com/openai/gym-http-api.");
        }
        //define the training
        QLearningDiscreteDense<Box> dql = new QLearningDiscreteDense<Box>(mdp, CARTPOLE_NET, CARTPOLE_QL, manager);

        //train
        dql.train();

        //get the final policy
        DQNPolicy<Box> pol = dql.getPolicy();

        //serialize and save (serialization showcase, but not required)
        pol.save("/tmp/pol1");

        //close the mdp (close http)
        mdp.close();

    }


测试:

 

 

    public static void loadCartpole(){

        //showcase serialization by using the trained agent on a new similar mdp (but render it this time)

        //define the mdp from gym (name, render)
        GymEnv<Box, Integer, DiscreteSpace> mdp2 = new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", true, false);

        //load the previous agent
        DQNPolicy<Box> pol2 = DQNPolicy.load("/tmp/pol1");

        //evaluate the agent
        double rewards = 0;
        for (int i = 0; i < 1000; i++) {
            mdp2.reset();
            double reward = pol2.play(mdp2);
            rewards += reward;
            Logger.getAnonymousLogger().info("Reward: " + reward);
        }

        Logger.getAnonymousLogger().info("average: " + rewards/1000);
        
        mdp2.close();

    }

在训练模型的方法中,包含了第二、三步的内容。首先需要定义gym数据读取对象,即代码中的GymEnv<Box, Integer, DiscreteSpace> mdp。它会通过gym-http-api的接口读取训练数据。接着,将第一步中定义的Q-learning的相关参数,神经网络结构作为参数传入DQN训练的包装类中。其中DataManager的作用是用来管理训练数据。

测试部分的代码实现了之前说的第四步,即加载策略模型并进行测试的过程。在测试的过程中,将每次action的reward打上log,并最后求取平均的reward。

训练的过程截图如下:

 

最后我们其实最关心的还是这个模型的效果。纯粹通过平均reward的数值大小可能并不是非常的直观,因此这里给出一张gif的效果图:

总结一下Cartpole问题的整个解决过程。首先我们明确,这是一个强化学习的问题,而不是传统的监督学习,因为涉及到与环境的交互等因素。然后,利用openAI提供的强化学习开发环境gym来构建训练平台,而RL4j则可以定义并训练DQN。最后的效果就是上面这张gif图片。需要注意的是,这张gif效果图并非是RL4j直接生成的,而是通过xvfb命令截取虚拟monitor的在每个action后的效果拼接起来的图。具体可先查阅xvfb的相关内容。

deeplearning4j是基于java的深度学习库,当然,它有许多特点,但暂时还没学那么深入,所以就不做介绍了 需要学习dl4j,无从下手,就想着先看看官网的examples,于是,下载了examples程序,结果无法运行,总是出错,如下: 查看一周的错误,也没有成功,马上就要放弃了,结果今天在论坛一大牛指导下,终于成功跑起,下面,将心酸的环境配置过程记录如下,以备自己以后查阅,同时,也希望各种高手可以指点,毕竟,本人还是菜鸟一枚 1.安装JAVA运行环境 该部分,网上有许多教程,这里不再赘述,首先,就是安装一个JDK,然后,再安装一个自己喜欢的IED,这里,以eclispe为例 好了,java的运行环境配置好了,接下来,开始配置dl4j的运行环境,它的官网上给了好复杂的设置步骤,照着做看一些后,发现根本无法进行,结果发现,不需要全部设置完成,就可以运行它的例子了,所以,本人并没有按照官网的教程全部设置,只是设置到了可以运行官网的examples为止,可能存在隐患吧,但本人能力有限,实在无从下手,还期待高手指定 2.按照Maven 按照教程安装Maven,该教程讲述非常详细 (1)下载Maven3,3,3,以win7 64位为例 下载地址:https://maven.apache.org/download.cgi (2)将Maven解压到某个文件夹中,这里以“C:\Program Files\apache-maven-3.3.3”为例 (3)配置环境变量:将maven中的bin的路径添加到system variables的PATH中 (4)测试maven是否安装成功 在命令行中输入mvn -version 如果如下下图所示结果,证明配置正确 3. 下载dl4j的examples,网址为: https://github.com/deeplearning4j/dl4j-0.4-examples 4.打开eclipse,导入刚刚下载的dl4j的examples,具体地: 打开eclipse后->File->import->Maven Existing Maven Projects,在Root Directory中选择examples的文件夹 然后,Finish 这样,examples被成功导入 当然,由于Maven会自动导入程序所需的jar文件(在配置文件pom.xml中所提及),所以,会花费一些时间自动下载这些文件 点击运行,出现如下错误: 这个问题困扰了本人一周,终于解决,是因为系统缺少dll文件所致 5. 下载dll文件,地址为https://www.dropbox.com/s/6p8yn3fcf230rxy/ND4J_Win64_OpenBLAS-v0.2.14.zip?dl=1 下载后,将该文件随意放入一个文件夹中,这里以“C:/BLAS”为例 将所有下载得到的dll文件放入该文件夹,并且,将该路径添加至环境变量Path中 6.此时,再运行刚刚的examples,发现程序终于可以正常运行了!
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

wangongxi

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值