Based on my testing, there are a few options for doing reinforcement learning in Java: Deeplearning4j (recent releases have dropped the rl4j module and it is no longer officially maintained, but it is still fine for small experiments), DJL (open-sourced by Amazon, currently the most complete option for deep learning in Java), and the common RL simulation platforms such as Gym and MuJoCo, which can also be driven from the Java side.
First, the rl4j project inside Deeplearning4j can handle basic reinforcement learning work; it ships with fundamental algorithms such as DQN, Q-learning and A3C.
Here is a basic Deeplearning4j example:
// Q-learning hyper-parameters
QLearning.QLConfiguration QL_CONFIG =
new QLearning.QLConfiguration(
123,      // random seed
1000,     // max steps per epoch
200*1000, // max total training steps
100*1000, // max size of the experience replay buffer
400,      // batch size
100,      // hard target-network update interval (every 100 steps)
0,        // number of no-op warm-up steps
0.01,     // reward scaling
0.9,      // gamma (discount factor)
1.0,      // td-error clipping
0.1f,     // minimum epsilon
100,      // number of steps for epsilon-greedy annealing
false     // double DQN
);
// DQN network configuration
DQNFactoryStdDense.Configuration DQN_NET =
DQNFactoryStdDense.Configuration.builder()
.updater(new Adam(0.01))
.numLayer(5)
.numHiddenNodes(16)
.build();
// Game environment wrapper implementing the MDP interface; it encapsulates step(), the reward, whether the game has ended, etc.
GameMDP mdp = new GameMDP();
QLearningDiscreteDense dqn = new QLearningDiscreteDense(mdp, DQN_NET, QL_CONFIG, new DataManager());
// DataManager handles checkpoint saving during training
// There is no need to check isDone or apply parameter updates yourself; rl4j handles all of that internally
dqn.train();
DQNPolicy pol = dqn.getPolicy();
pol.save("game.policy");
mdp.close();
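Once training is done, the saved policy can be loaded back and run against the environment. A minimal sketch (raw types kept as in the snippet above; rl4j's DQNPolicy.load() and play() may throw IOException):

// Load the previously saved policy and let it act greedily (no exploration)
GameMDP mdp = new GameMDP();
DQNPolicy pol = DQNPolicy.load("game.policy");
double reward = pol.play(mdp);
System.out.println("episode reward: " + reward);
mdp.close();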
For real projects, though, DJL is the way to go. The workflow is to implement the official RlEnv interface for interacting with the environment, then implement the Agent and define the networks.
For example, to implement the PPO algorithm, first define the agent:
package algorithm.ppo;
import ai.djl.engine.Engine;
import ai.djl.modality.rl.ActionSpace;
import ai.djl.modality.rl.agent.RlAgent;
import ai.djl.modality.rl.env.RlEnv;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.GradientCollector;
import ai.djl.training.Trainer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.translate.Batchifier;
import algorithm.CommonParameter;
import utils.ActionSampler;
import utils.Helper;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
public class PPOAgent implements RlAgent {
private Random random;
private Trainer policyTrainer;
private Trainer valueTrainer;
private float rewardDiscount;
private Batchifier batchifier;
public PPOAgent(Random random, Trainer policyTrainer, Trainer valueTrainer, float rewardDiscount) {
this.random = random;
this.policyTrainer = policyTrainer;
this.valueTrainer = valueTrainer;
this.rewardDiscount = rewardDiscount;
this.batchifier = Batchifier.STACK;
}
@Override
public NDList chooseAction(RlEnv env, boolean training) {
NDList[] inputs = buildInputs(env.getObservation());
NDArray actionScores =
policyTrainer.evaluate(batchifier.batchify(inputs)).singletonOrThrow().squeeze(-1);
int action;
if (training) {
action = ActionSampler.sampleMultinomial(actionScores, random);
} else {
action = Math.toIntExact(actionScores.argMax().getLong());
}
ActionSpace actionSpace = env.getActionSpace();
return actionSpace.get(action);
}
@Override
public void trainBatch(RlEnv.Step[] batchSteps) {
TrainingListener.BatchData batchData =
new TrainingListener.BatchData(null, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());
NDList[] preObservations = buildBatchPreObservation(batchSteps);
NDList[] actions = buildBatchAction(batchSteps);
NDList[] rewards = buildBatchReward(batchSteps);
boolean[] dones = buildBatchDone(batchSteps);
NDList policyOutput = policyTrainer.evaluate(batchifier.batchify(preObservations));
NDArray distribution = Helper.gather(policyOutput.singletonOrThrow().duplicate(), batchifier.batchify(actions).singletonOrThrow().toIntArray());
NDList valueOutput = valueTrainer.evaluate(batchifier.batchify(preObservations));
NDArray values = valueOutput.singletonOrThrow().duplicate();
NDList estimates = estimateAdvantage(values.duplicate(), batchifier.batchify(rewards).singletonOrThrow(), dones);
NDArray expectedReturns = estimates.get(0);
NDArray advantages = estimates.get(1);
// update critic
NDList valueOutputUpdated = valueTrainer.forward(batchifier.batchify(preObservations));
NDArray valuesUpdated = valueOutputUpdated.singletonOrThrow();
NDArray lossCritic = (expectedReturns.sub(valuesUpdated)).square().mean();
try (GradientCollector collector = Engine.getInstance().newGradientCollector()) {
collector.backward(lossCritic);
}
// update policy
NDList policyOutputUpdated = policyTrainer.forward(batchifier.batchify(preObservations));
NDArray distributionUpdated = Helper.gather(policyOutputUpdated.singletonOrThrow(), batchifier.batchify(actions).singletonOrThrow().toIntArray());
NDArray ratios = distributionUpdated.div(distribution);
NDArray surr1 = ratios.mul(advantages);
NDArray surr2 = ratios.clip(PPOParameter.RATIO_LOWER_BOUND, PPOParameter.RATIO_UPPER_BOUND).mul(advantages);
NDArray lossActor = surr1.minimum(surr2).mean().neg();
try (GradientCollector collector = Engine.getInstance().newGradientCollector()) {
collector.backward(lossActor);
}
// policyTrainer.notifyListeners(listener -> listener.onTrainingBatch(policyTrainer, batchData));
// valueTrainer.notifyListeners(listener -> listener.onTrainingBatch(valueTrainer, batchData));
}
private NDList[] buildInputs(NDList observation) {
return new NDList[]{observation};
}
public NDList[] buildBatchPreObservation(RlEnv.Step[] batchSteps) {
NDList[] result = new NDList[batchSteps.length];
for (int i = 0; i < batchSteps.length; i++) {
result[i] = batchSteps[i].getPreObservation();
}
return result;
}
public NDList[] buildBatchAction(RlEnv.Step[] batchSteps) {
NDList[] result = new NDList[batchSteps.length];
for (int i = 0; i < batchSteps.length; i++) {
result[i] = batchSteps[i].getAction();
}
return result;
}
public NDList[] buildBatchPostObservation(RlEnv.Step[] batchSteps) {
NDList[] result = new NDList[batchSteps.length];
for (int i = 0; i < batchSteps.length; i++) {
result[i] = batchSteps[i].getPostObservation();
}
return result;
}
public NDList[] buildBatchReward(RlEnv.Step[] batchSteps) {
NDList[] result = new NDList[batchSteps.length];
for (int i = 0; i < batchSteps.length; i++) {
result[i] = new NDList(batchSteps[i].getReward().expandDims(0));
}
return result;
}
public boolean[] buildBatchDone(RlEnv.Step[] batchSteps) {
boolean[] resultData = new boolean[batchSteps.length];
for (int i = 0; i < batchSteps.length; i++) {
resultData[i] = batchSteps[i].isDone();
}
return resultData;
}
private NDList estimateAdvantage(NDArray values, NDArray rewards, boolean[] masks) {
NDManager manager = rewards.getManager();
NDArray deltas = manager.create(rewards.getShape());
NDArray advantages = manager.create(rewards.getShape());
float prevValue = 0;
float prevAdvantage = 0;
for (int i = (int) rewards.getShape().get(0) - 1; i >= 0; i--) {
NDIndex index = new NDIndex(i);
int mask = masks[i] ? 0 : 1;
deltas.set(index, rewards.get(i).add(CommonParameter.GAMMA * prevValue * mask).sub(values.get(i)));
advantages.set(index, deltas.get(i).add(CommonParameter.GAMMA * CommonParameter.GAE_LAMBDA * prevAdvantage * mask));
prevValue = values.getFloat(i);
prevAdvantage = advantages.getFloat(i);
}
NDArray expectedReturns = values.add(advantages);
NDArray advantagesMean = advantages.mean();
NDArray advantagesStd = advantages.sub(advantagesMean).pow(2).sum().div(advantages.size() - 1).sqrt();
advantages = advantages.sub(advantagesMean).div(advantagesStd);
return new NDList(expectedReturns, advantages);
}
private NDArray getSample(NDManager subManager, NDArray array, int[] index) {
Shape shape = Shape.update(array.getShape(), 0, index.length);
NDArray sample = subManager.zeros(shape, array.getDataType());
for (int i = 0; i < index.length; i++) {
sample.set(new NDIndex(i), array.get(index[i]));
}
return sample;
}
}
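For reference, trainBatch() above is the standard PPO update. With γ = CommonParameter.GAMMA, λ = CommonParameter.GAE_LAMBDA, and the clip range given by PPOParameter.RATIO_LOWER_BOUND / RATIO_UPPER_BOUND (playing the role of 1−ε and 1+ε), estimateAdvantage() computes GAE advantages

δ_t = r_t + γ·V(s_{t+1})·(1 − done_t) − V(s_t)
A_t = δ_t + γ·λ·(1 − done_t)·A_{t+1}

normalizes them to zero mean and unit variance, trains the critic on the squared error between the returns V(s_t) + A_t and its updated predictions, and maximizes the clipped surrogate objective

L_CLIP(θ) = E_t[ min( r_t(θ)·A_t, clip(r_t(θ), 1−ε, 1+ε)·A_t ) ],  r_t(θ) = π_θ(a_t|s_t) / π_θold(a_t|s_t)

which is why lossActor is the negated mean of min(surr1, surr2).

ActionSampler.sampleMultinomial() is a small project utility that is not listed in this post; a minimal sketch of what it presumably does, matching the call in chooseAction() above (this is an assumption, not the original code):

import ai.djl.ndarray.NDArray;
import java.util.Random;

public final class ActionSampler {
    // Draw one action index from a 1-D probability vector
    public static int sampleMultinomial(NDArray probs, Random random) {
        float[] p = probs.toFloatArray();
        double u = random.nextDouble();
        double cumulative = 0.0;
        for (int i = 0; i < p.length; i++) {
            cumulative += p[i];
            if (u < cumulative) {
                return i;
            }
        }
        // Guard against floating-point rounding: fall back to the last action
        return p.length - 1;
    }
}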
Then define the networks and the Env, and start training:
public static TrainingResult runExample(String[] args) throws IOException {
Arguments arguments = new Arguments().parseArgs(args);
if (arguments == null) {
return null;
}
Engine.getInstance().setRandomSeed(0);
int epoch = 500;
int batchSize = 64;
int replayBufferSize = 2048;
int gamesPerEpoch = 128;
// Validation is deterministic, thus one game is enough
int validationGamesPerEpoch = 1;
float rewardDiscount = 0.9f;
if (arguments.getLimit() != Long.MAX_VALUE) {
gamesPerEpoch = Math.toIntExact(arguments.getLimit());
}
Random random = new Random(0);
NDManager mainManager = NDManager.newBaseManager();
CartPole env = new CartPole(mainManager, random, batchSize, replayBufferSize);
int stateSpaceDim = (int) env.getObservation().singletonOrThrow().getShape().get(0);
int actionSpaceDim = env.getActionSpace().size();
Model policyModel = Model.newInstance("discrete_policy_model");
BaseModelBlock policyNet = new DiscretePolicyModelBlock(actionSpaceDim, PPOParameter.POLICY_MODEL_HIDDEN_SIZE);
policyModel.setBlock(policyNet);
Model valueModel = Model.newInstance("critic_value_model");
BaseModelBlock valueNet = new CriticValueModelBlock(PPOParameter.CRITIC_MODEL_HIDDEN_SIZE);
valueModel.setBlock(valueNet);
DefaultTrainingConfig policyConfig = setupTrainingConfig();
DefaultTrainingConfig valueConfig = setupTrainingConfig();
Trainer policyTrainer = policyModel.newTrainer(policyConfig);
Trainer valueTrainer = valueModel.newTrainer(valueConfig);
policyTrainer.initialize(new Shape(batchSize, 4));
valueTrainer.initialize(new Shape(batchSize, 4));
policyTrainer.notifyListeners(listener -> listener.onTrainingBegin(policyTrainer));
valueTrainer.notifyListeners(listener -> listener.onTrainingBegin(valueTrainer));
PPOAgent agent = new PPOAgent(random, policyTrainer, valueTrainer, rewardDiscount);
for (int i = 0; i < epoch; i++) {
int episode = 0;
int size = 0;
while (size < replayBufferSize) {
episode++;
float result = env.runEnvironment(agent, true);
size += (int) result;
System.out.println("[" + episode + "]train:" + result);
}
for (int j = 0; j < gamesPerEpoch; j++) {
RlEnv.Step[] batchSteps = env.getBatch();
agent.trainBatch(batchSteps);
policyTrainer.step();
valueTrainer.step();
// System.out.println("train[" + i + "-" + j + "]:" + result);
}
for (int j = 0; j < validationGamesPerEpoch; j++) {
float result = env.runEnvironment(agent, false);
System.out.println("test:" + result);
}
}
policyTrainer.notifyListeners(listener -> listener.onTrainingEnd(policyTrainer));
valueTrainer.notifyListeners(listener -> listener.onTrainingEnd(valueTrainer));
return policyTrainer.getTrainingResult();
}
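setupTrainingConfig() is not shown above. A minimal sketch of what it could look like, reusing only DJL calls that appear later in this post (the l2 loss and the 1e-3 learning rate are assumptions, not the original values):

private static DefaultTrainingConfig setupTrainingConfig() {
    // The PPO losses are computed manually in trainBatch(); this config mainly supplies the optimizer
    return new DefaultTrainingConfig(Loss.l2Loss())
            .optOptimizer(Adam.builder()
                    .optLearningRateTracker(Tracker.fixed(1e-3f))
                    .build())
            .addTrainingListeners(TrainingListener.Defaults.basic());
}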
FlappyBird
Below is an example of using DJL to train a DQN agent to play FlappyBird.
The first step is implementing the game itself. The game code was taken from GitHub; the only change needed is that after each frame is executed, the game state (the bird's position, whether it collided, the full game frame, whether it passed a pipe) is returned, so it can be used for reward shaping and as the neural network input:
/**
* @desc : Execute one step and return the game state
* @auth : tyf
* @date : 2023-10-20 10:32:30
*/
public GameValue step(int action) {
if (action == 1) {
bird.birdFlap();
}
// Return a game state
GameValue gameValue = stepFrame();
if (this.withGraphics) {
try {
// Return the frame rendered after this step; it is later used as the neural network input
gameValue.setImage(currentImg);
Thread.sleep(FPS);
} catch (InterruptedException e) {
e.printStackTrace();
}
// Render the frame
repaint();
}
return gameValue;
}
/**
* @desc : Basic game functionality test
* @auth : tyf
* @date : 2023-10-20 10:30:26
*/
public static void checkGame(){
FlappyBird flappyBird = new FlappyBird(true);
while (true){
// The action is randomly generated as 0 or 1
int action = new Random().nextInt(10)%2==0?1:0;
// Each step returns a game state,
// which the Env later uses to decide the reward and whether the game is over
FlappyBird.GameValue v = flappyBird.step(action);
// Get the current frame; it is used as the neural network input
BufferedImage img = v.getImage();
// Restart when the game is over
if(v.isDone()){
flappyBird.restartGame();
}
System.out.println("当前状态:"+v);
}
}
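The GameValue class itself is not listed here; based on how it is used above and in the Env below, it needs roughly the following shape (a hypothetical sketch showing only the getters actually called):

// Hypothetical sketch of the per-frame state returned by FlappyBird.step()
public static class GameValue {
    private BufferedImage image;  // the rendered frame, later fed to the neural network
    private boolean done;         // whether the bird crashed and the episode ended
    private long score;           // number of pipes passed so far
    private FlappyBird.TYPE type; // CRASH_MAYBE / CRASH_NO / GET_SCORE / GET_SCORE_MAYBE

    public BufferedImage getImage() { return image; }
    public void setImage(BufferedImage image) { this.image = image; }
    public boolean isDone() { return done; }
    public long getScore() { return score; }
    public FlappyBird.TYPE getType() { return type; }
}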
Next we define the environment and the agent. The Env glues the game and the RL framework together: calling Env.step() drives the game's step(), then decides the reward, whether the episode is over, and so on:
package org.env;
import ai.djl.Model;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.modality.rl.ActionSpace;
import ai.djl.modality.rl.LruReplayBuffer;
import ai.djl.modality.rl.ReplayBuffer;
import ai.djl.modality.rl.agent.RlAgent;
import ai.djl.modality.rl.env.RlEnv;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import org.game.FlappyBird;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayDeque;
import java.util.Queue;
/**
* @desc : Wraps the game and the training environment
* @auth : tyf
* @date : 2023-10-20 11:55:25
*/
public class GameEnv implements RlEnv {
// The game instance
private FlappyBird flappyBird;
// Training environment
private final NDManager manager; // NDArray manager
private final LruReplayBuffer replayBuffer; // experience replay buffer
private NDList currentObservation; // observation of the current game frame
private ActionSpace actionSpace; // action space
// Observation cache holding the 4 most recent frames; the 4 tensors are stacked into a single network input
private final Queue<NDArray> imgQueue = new ArrayDeque<>(4);
// Counter tracking game / training progress
public static int trainStep = 0;
// Neural network
public static Model model;
/**
* @desc : Create the game environment
* @auth : tyf
* @date : 2023-10-20 14:45:08
*/
public GameEnv(FlappyBird flappyBird, NDManager manager, LruReplayBuffer replayBuffer, Model model) {
// Game instance
this.flappyBird = flappyBird;
// NDArray manager
this.manager = manager;
// Experience replay buffer
this.replayBuffer = replayBuffer;
// Neural network
this.model = model;
// Initialize the action space; this game only has two actions
actionSpace = new ActionSpace();
actionSpace.add(new NDList(manager.create(new int[]{0,1}))); // DO_NOTHING (don't flap)
actionSpace.add(new NDList(manager.create(new int[]{1,0}))); // FLAP
// Initialize the current observation.
// currentObservation is a single tensor stacked from four frames; at start-up there are
// no four consecutive frames yet, so the same frame is simply repeated four times
currentObservation = createObservation(flappyBird.getCurrentImg());
}
@Override
public void reset() {
// Restart the game
flappyBird.restartGame();
}
/**
* @desc : Return the game's current observation
* @auth : tyf
* @date : 2023-10-20 15:24:46
*/
@Override
public NDList getObservation() {
return this.currentObservation;
}
/**
* @desc : Return the action space defined by the game; the agent picks an action from it according to its exploration policy
* @auth : tyf
* @date : 2023-10-20 15:25:19
*/
@Override
public ActionSpace getActionSpace() {
return this.actionSpace;
}
/**
* @return
* @desc : Execute one step given the action chosen by the agent's exploration policy; this is where the game itself is driven
* @auth : tyf
* @date : 2023-10-20 15:27:12
*/
@Override
public Step step(NDList action, boolean b) {
// First drive the game and get its output:
// the Env action is converted into the game's action
int act = action.singletonOrThrow().getInt(0);
FlappyBird.GameValue gameOut = flappyBird.step(act);
// Previous observation
NDList preObservation = currentObservation;
// New observation
currentObservation = createObservation(gameOut.getImage());
// Shape the reward based on the game output to speed up training
float reward = 0.0f;
// The bird's score: number of pipes passed
long score = gameOut.getScore();
// Reward shaping:
//   game over (hit the ceiling, the ground or a pipe)         -1
//   CRASH_MAYBE: badly positioned, may hit the upcoming pipe   0.1
//   CRASH_NO: well positioned, will not hit the upcoming pipe  0.2
//   GET_SCORE: passed a pipe                                   0.8
//   GET_SCORE_MAYBE: currently inside the pipe gap             1.0
if (gameOut.isDone()) {
reward = -1.0f;
} else {
// May hit the upcoming pipe
if (gameOut.getType().equals(FlappyBird.TYPE.CRASH_MAYBE)) {
reward = 0.1f;
}
// Will not hit the upcoming pipe
else if (gameOut.getType().equals(FlappyBird.TYPE.CRASH_NO)) {
reward = 0.2f;
}
// Passed a pipe
else if (gameOut.getType().equals(FlappyBird.TYPE.GET_SCORE)) {
reward = 0.8f;
}
// Inside the pipe gap
else if (gameOut.getType().equals(FlappyBird.TYPE.GET_SCORE_MAYBE)) {
reward = 1f;
}
}
// Whether the game is over
boolean done = gameOut.isDone();
// Wrap this step and add it to the experience replay buffer
Step s = new GameStep(manager.newSubManager(), actionSpace, preObservation, currentObservation, action, reward, done);
replayBuffer.addStep(s);
// Restart the game if it ended
if (done) {
flappyBird.restartGame();
}
trainStep++;
System.out.println("trainStep:"+trainStep+", action:"+act+", reward:"+reward+", done:"+done+", score:"+score);
return s;
}
@Override
public Step[] getBatch() {
return replayBuffer.getBatch();
}
@Override
public void close() {
}
/**
* @desc : Run the environment; this is the training entry point that calls step()
* @auth : tyf
* @date : 2023-10-20 15:13:54
*/
@Override
public float runEnvironment(RlAgent agent, boolean training) {
// Choose an action
NDList action = agent.chooseAction(this, training);
// Execute it
this.step(action,training);
// Every 10000 steps, update the network and save a checkpoint
if(trainStep%10000==0){
// Update the neural network
agent.trainBatch(getBatch());
try {
model.save(Paths.get("src/main/resources/model"), "dqn-" + trainStep);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
return 0;
}
/**
* @desc : Convert each frame into an input tensor; every 4 consecutive frames are stacked into a single network input
* @auth : tyf
* @date : 2023-10-20 09:59:40
*/
public NDList createObservation(BufferedImage currentImg) {
// Convert the frame into a grayscale 80x80 tensor
NDArray observation = NDImageUtils.toTensor(
NDImageUtils.resize(
ImageFactory.getInstance().fromImage(currentImg).toNDArray(NDManager.newBaseManager(), Image.Flag.GRAYSCALE),
80, 80));
if (imgQueue.isEmpty()) {
for (int i = 0; i < 4; i++) {
imgQueue.offer(observation);
}
return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
} else {
imgQueue.remove();
imgQueue.offer(observation);
NDArray[] buf = new NDArray[4];
int i = 0;
for (NDArray nd : imgQueue) {
buf[i++] = nd;
}
return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
}
}
/**
* @desc : Wraps one step so it can be added to the experience replay buffer
* @auth : tyf
* @date : 2023-10-20 15:32:16
*/
static final class GameStep implements Step{
private final NDManager manager;
private final ActionSpace actionSpace;
private final NDList preObservation;
private final NDList postObservation;
private final NDList action;
private final float reward;
private final boolean done;
public GameStep(NDManager manager, ActionSpace actionSpace,NDList preObservation, NDList postObservation, NDList action, float reward, boolean done) {
this.manager = manager;
this.actionSpace = actionSpace;
this.preObservation = preObservation;
this.postObservation = postObservation;
this.action = action;
this.reward = reward;
this.done = done;
}
@Override
public NDList getPreObservation() {
preObservation.attach(manager);
return preObservation;
}
@Override
public NDList getPostObservation() {
return postObservation;
}
@Override
public NDList getAction() {
return action;
}
@Override
public ActionSpace getPostActionSpace() {
return actionSpace;
}
@Override
public NDArray getReward() {
return manager.create(reward);
}
@Override
public boolean isDone() {
return done;
}
@Override
public void close() {
this.manager.close();
}
}
}
When creating the game environment you also have to set up the neural network, i.e. initialize the trainer. After that you just define the total number of steps and checkpoint periodically; the agent's trainBatch() pulls a batch of states from the replay buffer and uses it to update the network parameters. The overall training code is as follows:
/**
* @desc : Training
* @auth : tyf
* @date : 2023-10-20 14:40:43
*/
public static void train() throws Exception{
// First define the network structure
Model model = Model.newInstance("QNetwork");
// conv -> conv -> conv -> fc -> fc
SequentialBlock block = new SequentialBlock()
.add(Conv2d.builder().setKernelShape(new Shape(8, 8)).optStride(new Shape(4, 4)).optPadding(new Shape(3, 3)).setFilters(4).build())
.add(Activation::relu)
.add(Conv2d.builder().setKernelShape(new Shape(4, 4)).optStride(new Shape(2, 2)).setFilters(32).build())
.add(Activation::relu)
.add(Conv2d.builder().setKernelShape(new Shape(3, 3)).optStride(new Shape(1, 1)).setFilters(64).build())
.add(Activation::relu)
.add(Blocks.batchFlattenBlock())
.add(Linear.builder().setUnits(512).build())
.add(Activation::relu)
.add(Linear.builder().setUnits(2).build());
model.setBlock(block);
// Training and optimizer configuration
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss())
.optOptimizer(Adam.builder().optLearningRateTracker(Tracker.fixed(1e-6f)).build())
.addEvaluator(new Accuracy())
.optInitializer(new NormalInitializer())
.addTrainingListeners(TrainingListener.Defaults.basic());
// Whether to start from pre-trained weights; if so, load them from src/main/resources/model/dqn-trained-nnnn.params
boolean usePreTrained = false;
if(usePreTrained){
model.load(Paths.get("src/main/resources/model"), "dqn-trained");
}
// Create the trainer
int batchSize = 32;
Trainer trainer = model.newTrainer(config);
trainer.initialize(new Shape(batchSize, 4, 80, 80)); // model input shape
trainer.notifyListeners(listener -> listener.onTrainingBegin(trainer));
// Maximum total number of training steps
int steps = 3_000_000;
// Exploration tracker defining the exploration schedule:
// more exploration early on, more exploitation later
Tracker exploreRate = new LinearTracker.Builder()
.setBaseValue(0.01f) // initial exploration rate
.optSlope(-(0.01f - 0.00001f) / steps) // decay slope of the exploration rate
.optMinValue(0.00001f) // keep decaying until this minimum value
.build();
// NDArray manager
NDManager manager = NDManager.newBaseManager();
// The agent: gamma=0.9f (the larger it is, the more future rewards matter). GameAgent plays the role of DJL's built-in QAgent, i.e. the default Q-learning implementation;
// to implement Double DQN you would override its trainBatch() method
RlAgent agent = new EpsilonGreedy(new GameAgent(trainer, 0.9f,manager), exploreRate);
// Whether to display the game window
boolean show = true;
// Build the experience replay buffer
int replayBufferSize = 50000;
LruReplayBuffer replayBuffer = new LruReplayBuffer(batchSize, replayBufferSize);
// Build the environment
GameEnv env = new GameEnv(new FlappyBird(show), manager, replayBuffer, model);
// Start training, passing in the agent
for (int i = 0; i < steps; i++) {
env.runEnvironment(agent,true);
}
}
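GameAgent itself is not listed in this post. Its chooseAction() is essentially the same greedy lookup as in the PPO agent above (EpsilonGreedy adds the exploration on top), and trainBatch() performs a one-step Q-learning update. The following is a rough sketch of what such an agent could look like, not the original code; it only uses DJL calls that already appear above:

package org.agent;

import ai.djl.engine.Engine;
import ai.djl.modality.rl.agent.RlAgent;
import ai.djl.modality.rl.env.RlEnv;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.GradientCollector;
import ai.djl.training.Trainer;

// Hypothetical sketch of GameAgent: greedy action selection plus a one-step Q-learning update
public class GameAgent implements RlAgent {
    private final Trainer trainer;
    private final float rewardDiscount;
    private final NDManager manager;

    public GameAgent(Trainer trainer, float rewardDiscount, NDManager manager) {
        this.trainer = trainer;
        this.rewardDiscount = rewardDiscount;
        this.manager = manager;
    }

    @Override
    public NDList chooseAction(RlEnv env, boolean training) {
        // The observation from GameEnv already carries a batch dimension of 1, so it can be fed directly;
        // random exploration is injected by the EpsilonGreedy wrapper during training
        NDArray qValues = trainer.evaluate(env.getObservation()).singletonOrThrow();
        int best = Math.toIntExact(qValues.argMax().getLong());
        return env.getActionSpace().get(best);
    }

    @Override
    public void trainBatch(RlEnv.Step[] batchSteps) {
        NDList preList = new NDList();
        NDList postList = new NDList();
        NDList actionList = new NDList();
        float[] rewards = new float[batchSteps.length];
        float[] notDone = new float[batchSteps.length];
        for (int i = 0; i < batchSteps.length; i++) {
            // Each observation already has a leading batch dimension of 1, so the batch is built with concat
            preList.add(batchSteps[i].getPreObservation().singletonOrThrow());
            postList.add(batchSteps[i].getPostObservation().singletonOrThrow());
            // Actions are stored as one-hot vectors {1,0} / {0,1}; add a batch dimension
            actionList.add(batchSteps[i].getAction().singletonOrThrow().toType(DataType.FLOAT32, false).expandDims(0));
            rewards[i] = batchSteps[i].getReward().getFloat();
            notDone[i] = batchSteps[i].isDone() ? 0f : 1f;
        }
        NDArray preBatch = NDArrays.concat(preList, 0);
        NDArray postBatch = NDArrays.concat(postList, 0);
        NDArray actionBatch = NDArrays.concat(actionList, 0);
        try (GradientCollector collector = Engine.getInstance().newGradientCollector()) {
            // Q(s, a) of the actions actually taken
            NDArray q = trainer.forward(new NDList(preBatch)).singletonOrThrow()
                    .mul(actionBatch)
                    .sum(new int[] {1});
            // One-step TD target: r + gamma * max_a' Q(s', a'), masked out on terminal states
            NDArray maxNextQ = trainer.evaluate(new NDList(postBatch)).singletonOrThrow()
                    .max(new int[] {1});
            NDArray target = manager.create(rewards)
                    .add(maxNextQ.mul(rewardDiscount).mul(manager.create(notDone)));
            NDArray loss = target.sub(q).square().mean();
            collector.backward(loss);
        }
    }
}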
After that it just runs as expected; the agent only learned to get past the first pipe at around 200k steps.
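To watch a trained bird play, the saved parameters can be loaded back into the same network and the environment run without exploration. A minimal sketch (the checkpoint name and the buildQNetworkBlock() helper, which would return the same SequentialBlock as in train(), are placeholders):

public static void play() throws Exception {
    // Rebuild the same network structure used for training and load the saved weights
    Model model = Model.newInstance("QNetwork");
    model.setBlock(buildQNetworkBlock()); // hypothetical helper returning the SequentialBlock from train()
    model.load(Paths.get("src/main/resources/model"), "dqn-trained");

    NDManager manager = NDManager.newBaseManager();
    Trainer trainer = model.newTrainer(new DefaultTrainingConfig(Loss.l2Loss()));
    trainer.initialize(new Shape(1, 4, 80, 80));

    // No EpsilonGreedy wrapper here: the agent always picks the greedy action
    RlAgent agent = new GameAgent(trainer, 0.9f, manager);
    GameEnv env = new GameEnv(new FlappyBird(true), manager, new LruReplayBuffer(32, 50000), model);
    while (true) {
        env.runEnvironment(agent, false);
    }
}

Note that runEnvironment() above still calls trainBatch() and saves a checkpoint every 10000 steps, so for a pure demo you may want to guard that branch with the training flag.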