Eclipse Deeplearning4j GitChat课程:https://gitbook.cn/gitchat/column/5bfb6741ae0e5f436e35cd9f
Eclipse Deeplearning4j 系列博客:https://blog.csdn.net/wangongxi
Eclipse Deeplearning4j Github:https://github.com/eclipse/deeplearning4j
在之前的博客中,我用Deeplearning4j构建深度神经网络来解决监督、无监督的机器学习问题。但除了这两类问题外,强化学习也是机器学习中一个重要的分支,并且Deeplearning4j的子项目--Rl4j提供了对部分强化学习算法的支持。这里,就以强化学习中的经典任务--Cartpole问题作为学习Rl4j的入门例子,讲解从环境搭建、模型训练再到最后的效果评估的结果。
Cartpole描述的问题可以认为是:在一辆小车上竖立一根杆子,然后给小车一个推或者拉的力,使得杆子尽量保持平衡不滑倒。更详细的描述可参见openai官网上关于Cartpole问题的解释:https://gym.openai.com/envs/CartPole-v0
接着给出强化学习的一些概念:environment,action,reward
environment:描述强化学习问题中的外部环境,比如:Cartpole问题中杆子的角度,小车的位置、速度等。
action:在不同外部环境条件下采取的动作,比如:Cartpole问题中对于小车施加推或者拉的力。action可以是离散的集合,也可以是连续的。
reward:对于agent/network作出的action后获取的回报/评价。比如:Cartpole问题中如果施加的力可以继续让杆子保持平衡,那reward就+1。
在描述reward这个概念时,提到了agent这个概念,在实际应用中,agent可以用神经网络来实现。
对于强化学习训练后的agent来说,学习到的是如何在变化中的environment和reward选择action的能力。通常有两种学习策略可以选择:Policy-Based和Value-Based。 Policy-Based直接学习action,通过Policy Gradient来更新模型参数,而相对的,Value-Based是最优化action所带来的reward(action-value function,Q-function)来间接选取action。一般认为如果action是离散的,那么Value-Based会优于Policy-Based,而连续的action则相反。在这里主要讨论Value-Based的学习策略,或者更具体的说Q-learning的问题。对于Policy-Based还有Model-Based不做讨论。
Q-learning的概念早在20多年前就已经提出,再与近年来流行的深度神经网络结合产生了DQN的概念。Q-learning的目标是最大化Q值从而学习到选取action的策略。Q-leaning学习的策略公式:
Q(st,at)←Q(st,at)+α[rt+1+λmaxaQ(st+1,a)−Q(st,at)]
对于这里主要讨论的Catpole问题,我们也采用Q-learning来实现。
可以看到,与监督学习相比,强化学习多了action,environment等概念。虽然可以将reward类比成监督学习中的label(或者反过来,label也可以认为是强化学习中最终的reward),但通过action与environment不断的交互甚至改变environment这一特点,是监督学习中所没有的。在构建应用的时候,监督学习的学习的目标:label,灌入的数据都是一个定值。比如,图像的分类的问题,在用CNN训练的时候,图片本身不发生变化,label也不会发生变化,唯一变化的是神经网络中的权重值。但强化学习在训练的时候,除了神经网络中的权重会发生变化(如果用NN建模的话),environment、reward等都会发生动态的变化。这样构建合适正确的训练数据会比较麻烦,容易出错,所以对于CartPole问题,我们可以采用openAI提供的强化学习开发环境gym来训练/测试agent。
gym的官方地址:https://gym.openai.com/
gym提供了棋类、视频游戏等强化学习问题的学习/测试/算法效果比较的环境。这里要处理的Cartpole问题,gym也提供了环境的支持。但是,除了python,gym对其他语言的支持不是很友好,所以为了可以获取gym中的数据,RL4j提供了对gym-http-api(https://github.com/openai/gym-http-api)调用的包装类。gym-http-api是为了方便除python外的其他语言也可以使用gym环境数据的一个REST接口。简单来说,对于像RL4j这样以Java实现的强化学习算法库可以通过gym-http-api获取gym提供的数据。
gym的REST接口的安装可以参见之前给出的github地址,里面有详细的描述。下面先给出gym-http-api的安装和启动过程的截图:
下面就结合上面说的内容,给出RL4j的Catpole实现逻辑
1. 定义Q-learning的参数以及神经网络结构,两者共同决定DQN的属性
2. 定义读取gym数据的包装类对象
3. 训练DQN并保存模型
4. 加载保存的模型并测试
这里先贴下需要的Maven依赖以及代码版本
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<nd4j.version>0.8.0</nd4j.version>
<dl4j.version>0.8.0</dl4j.version>
<datavec.version>0.8.0</datavec.version>
<rl4j.version>0.8.0</rl4j.version>
<scala.binary.version>2.10</scala.binary.version>
</properties>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>rl4j-core</artifactId>
<version>${rl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>rl4j-gym</artifactId>
<version>${rl4j.version}</version>
</dependency>
</dependencies>
第一部分的代码如下:
public static QLearning.QLConfiguration CARTPOLE_QL =
new QLearning.QLConfiguration(
123, //Random seed
200, //Max step By epoch
150000, //Max step
150000, //Max size of experience replay
32, //size of batches
500, //target update (hard)
10, //num step noop warmup
0.01, //reward scaling
0.99, //gamma
1.0, //td-error clipping
0.1f, //min epsilon
1000, //num step for eps greedy anneal
true //double DQN
);
public static DQNFactoryStdDense.Configuration CARTPOLE_NET = DQNFactoryStdDense.Configuration.builder() .l2(0.001) .learningRate(0.0005)
.numHiddenNodes(16)
.numLayer(3)
.build();
第一部分中定义Q-learning的参数,包括每一轮的训练的可采取的action的步数,最大步数以及存储过往action的最大步数等。除此以外,DQNFactoryStdDense用来定义基于MLP的DQN网络结构,包括网络深度等常见参数。这里的代码定义的是一个三层(只有一层隐藏层)的全连接神经网络。
接下来,定义两个方法分别用于训练和测试。catpole方法用于训练DQN,而loadCartpole则用于测试。
训练:
public static void cartPole() {
//record the training data in rl4j-data in a new folder (save)
DataManager manager = new DataManager(true);
//define the mdp from gym (name, render)
GymEnv<Box, Integer, DiscreteSpace> mdp = null;
try {
mdp = new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", false, false);
} catch (RuntimeException e){
System.out.print("To run this example, download and start the gym-http-api repo found at https://github.com/openai/gym-http-api.");
}
//define the training
QLearningDiscreteDense<Box> dql = new QLearningDiscreteDense<Box>(mdp, CARTPOLE_NET, CARTPOLE_QL, manager);
//train
dql.train();
//get the final policy
DQNPolicy<Box> pol = dql.getPolicy();
//serialize and save (serialization showcase, but not required)
pol.save("/tmp/pol1");
//close the mdp (close http)
mdp.close();
}
测试:
public static void loadCartpole(){
//showcase serialization by using the trained agent on a new similar mdp (but render it this time)
//define the mdp from gym (name, render)
GymEnv<Box, Integer, DiscreteSpace> mdp2 = new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", true, false);
//load the previous agent
DQNPolicy<Box> pol2 = DQNPolicy.load("/tmp/pol1");
//evaluate the agent
double rewards = 0;
for (int i = 0; i < 1000; i++) {
mdp2.reset();
double reward = pol2.play(mdp2);
rewards += reward;
Logger.getAnonymousLogger().info("Reward: " + reward);
}
Logger.getAnonymousLogger().info("average: " + rewards/1000);
mdp2.close();
}
在训练模型的方法中,包含了第二、三步的内容。首先需要定义gym数据读取对象,即代码中的GymEnv<Box, Integer, DiscreteSpace> mdp。它会通过gym-http-api的接口读取训练数据。接着,将第一步中定义的Q-learning的相关参数,神经网络结构作为参数传入DQN训练的包装类中。其中DataManager的作用是用来管理训练数据。
测试部分的代码实现了之前说的第四步,即加载策略模型并进行测试的过程。在测试的过程中,将每次action的reward打上log,并最后求取平均的reward。
训练的过程截图如下:
最后我们其实最关心的还是这个模型的效果。纯粹通过平均reward的数值大小可能并不是非常的直观,因此这里给出一张gif的效果图:
总结一下Cartpole问题的整个解决过程。首先我们明确,这是一个强化学习的问题,而不是传统的监督学习,因为涉及到与环境的交互等因素。然后,利用openAI提供的强化学习开发环境gym来构建训练平台,而RL4j则可以定义并训练DQN。最后的效果就是上面这张gif图片。需要注意的是,这张gif效果图并非是RL4j直接生成的,而是通过xvfb命令截取虚拟monitor的在每个action后的效果拼接起来的图。具体可先查阅xvfb的相关内容。