Reinforcement Learning with DJL (with a FlappyBird Example)

Based on my testing, there are a few options for reinforcement learning in Java: Deeplearning4j (its rl4j module has been removed in recent releases and is no longer officially maintained, but it is still fine for small experiments), DJL (open-sourced by Amazon and currently the most complete deep learning stack on the JVM), and for simulation platforms such as gym and mujoco there are also ways to drive them from Java.

First, the rl4j module in Deeplearning4j can handle basic reinforcement learning work and ships with fundamental algorithms such as DQN, Q-learning and A3C.

Here is a basic Deeplearning4j example:

		// Q-learning hyperparameters
		QLearning.QLConfiguration QL_CONFIG =
				new QLearning.QLConfiguration(
						123,        // random seed
						1000,       // max steps per epoch
						200 * 1000, // max total steps
						100 * 1000, // max size of the experience replay buffer
						400,        // batch size
						100,        // target network (hard) update interval, in steps
						0,          // number of no-op warmup steps
						0.01,       // reward scaling
						0.9,        // gamma (discount factor)
						1.0,        // td-error clipping
						0.1f,       // min epsilon
						100,        // number of steps for epsilon-greedy annealing
						false       // double DQN
				);

		// DQN network
		DQNFactoryStdDense.Configuration DQN_NET =
				DQNFactoryStdDense.Configuration.builder()
						.updater(new Adam(0.01))
						.numLayer(5)
						.numHiddenNodes(16)
						.build();


		// Game wrapper: encapsulates step, reward, whether the game is done, etc., and implements the MDP interface
		GameMDP mdp = new GameMDP();
		QLearningDiscreteDense dqn = new QLearningDiscreteDense(mdp, DQN_NET, QL_CONFIG, new DataManager());

		// The DataManager handles checkpoint saving
		// During training there is no need to check isDone or update parameters yourself; rl4j handles that internally
		DQNPolicy pol = dqn.getPolicy();
		dqn.train();
		pol.save("game.policy");
		mdp.close();
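
After training, the saved policy can be reloaded and run against the same MDP as a quick sanity check. A minimal sketch (raw types kept for brevity, matching the snippet above):

		// Reload the trained policy and let it play one episode (hedged sketch)
		DQNPolicy loadedPolicy = DQNPolicy.load("game.policy");
		GameMDP evalMdp = new GameMDP();
		double episodeReward = loadedPolicy.play(evalMdp);
		System.out.println("evaluation reward: " + episodeReward);
		evalMdp.close();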

For real projects, DJL is the better choice. The workflow is: implement the official Env interface (RlEnv) to interact with the environment, then implement the agent (RlAgent) and define the networks.

For example, to implement the PPO algorithm, first define the agent:

package algorithm.ppo;

import ai.djl.engine.Engine;
import ai.djl.modality.rl.ActionSpace;
import ai.djl.modality.rl.agent.RlAgent;
import ai.djl.modality.rl.env.RlEnv;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.GradientCollector;
import ai.djl.training.Trainer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.translate.Batchifier;
import algorithm.CommonParameter;
import utils.ActionSampler;
import utils.Helper;

import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;


public class PPOAgent implements RlAgent {

    private Random random;
    private Trainer policyTrainer;
    private Trainer valueTrainer;
    private float rewardDiscount;
    private Batchifier batchifier;

    public PPOAgent(Random random, Trainer policyTrainer, Trainer valueTrainer, float rewardDiscount) {
        this.random = random;
        this.policyTrainer = policyTrainer;
        this.valueTrainer = valueTrainer;
        this.rewardDiscount = rewardDiscount;
        this.batchifier = Batchifier.STACK;
    }

    @Override
    public NDList chooseAction(RlEnv env, boolean training) {
        NDList[] inputs = buildInputs(env.getObservation());
        NDArray actionScores =
                policyTrainer.evaluate(batchifier.batchify(inputs)).singletonOrThrow().squeeze(-1);
        int action;
        if (training) {
            action = ActionSampler.sampleMultinomial(actionScores, random);
        } else {
            action = Math.toIntExact(actionScores.argMax().getLong());
        }
        ActionSpace actionSpace = env.getActionSpace();
        return actionSpace.get(action);
    }

    @Override
    public void trainBatch(RlEnv.Step[] batchSteps) {
        TrainingListener.BatchData batchData =
                new TrainingListener.BatchData(null, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());
        NDList[] preObservations = buildBatchPreObservation(batchSteps);
        NDList[] actions = buildBatchAction(batchSteps);
        NDList[] rewards = buildBatchReward(batchSteps);
        boolean[] dones = buildBatchDone(batchSteps);

        NDList policyOutput = policyTrainer.evaluate(batchifier.batchify(preObservations));
        NDArray distribution = Helper.gather(policyOutput.singletonOrThrow().duplicate(), batchifier.batchify(actions).singletonOrThrow().toIntArray());

        NDList valueOutput = valueTrainer.evaluate(batchifier.batchify(preObservations));
        NDArray values = valueOutput.singletonOrThrow().duplicate();
        NDList estimates = estimateAdvantage(values.duplicate(), batchifier.batchify(rewards).singletonOrThrow(), dones);
        NDArray expectedReturns = estimates.get(0);
        NDArray advantages = estimates.get(1);

        // update critic
        NDList valueOutputUpdated = valueTrainer.forward(batchifier.batchify(preObservations));
        NDArray valuesUpdated = valueOutputUpdated.singletonOrThrow();
        NDArray lossCritic = (expectedReturns.sub(valuesUpdated)).square().mean();
        try (GradientCollector collector = Engine.getInstance().newGradientCollector()) {
            collector.backward(lossCritic);
        }

        // update policy
        NDList policyOutputUpdated = policyTrainer.forward(batchifier.batchify(preObservations));
        NDArray distributionUpdated = Helper.gather(policyOutputUpdated.singletonOrThrow(), batchifier.batchify(actions).singletonOrThrow().toIntArray());
        NDArray ratios = distributionUpdated.div(distribution);

        NDArray surr1 = ratios.mul(advantages);
        NDArray surr2 = ratios.clip(PPOParameter.RATIO_LOWER_BOUND, PPOParameter.RATIO_UPPER_BOUND).mul(advantages);
        NDArray lossActor = surr1.minimum(surr2).mean().neg();

        try (GradientCollector collector = Engine.getInstance().newGradientCollector()) {
            collector.backward(lossActor);
        }
//        policyTrainer.notifyListeners(listener -> listener.onTrainingBatch(policyTrainer, batchData));
//        valueTrainer.notifyListeners(listener -> listener.onTrainingBatch(valueTrainer, batchData));
    }

    private NDList[] buildInputs(NDList observation) {
        return new NDList[]{observation};
    }

    public NDList[] buildBatchPreObservation(RlEnv.Step[] batchSteps) {
        NDList[] result = new NDList[batchSteps.length];
        for (int i = 0; i < batchSteps.length; i++) {
            result[i] = batchSteps[i].getPreObservation();
        }
        return result;
    }

    public NDList[] buildBatchAction(RlEnv.Step[] batchSteps) {
        NDList[] result = new NDList[batchSteps.length];
        for (int i = 0; i < batchSteps.length; i++) {
            result[i] = batchSteps[i].getAction();
        }
        return result;
    }

    public NDList[] buildBatchPostObservation(RlEnv.Step[] batchSteps) {
        NDList[] result = new NDList[batchSteps.length];
        for (int i = 0; i < batchSteps.length; i++) {
            result[i] = batchSteps[i].getPostObservation();
        }
        return result;
    }

    public NDList[] buildBatchReward(RlEnv.Step[] batchSteps) {
        NDList[] result = new NDList[batchSteps.length];
        for (int i = 0; i < batchSteps.length; i++) {
            result[i] = new NDList(batchSteps[i].getReward().expandDims(0));
        }
        return result;
    }

    public boolean[] buildBatchDone(RlEnv.Step[] batchSteps) {
        boolean[] resultData = new boolean[batchSteps.length];
        for (int i = 0; i < batchSteps.length; i++) {
            resultData[i] = batchSteps[i].isDone();
        }
        return resultData;
    }

    private NDList estimateAdvantage(NDArray values, NDArray rewards, boolean[] masks) {
        NDManager manager = rewards.getManager();
        NDArray deltas = manager.create(rewards.getShape());
        NDArray advantages = manager.create(rewards.getShape());

        float prevValue = 0;
        float prevAdvantage = 0;
        for (int i = (int) rewards.getShape().get(0) - 1; i >= 0; i--) {
            NDIndex index = new NDIndex(i);
            int mask = masks[i] ? 0 : 1;
            deltas.set(index, rewards.get(i).add(CommonParameter.GAMMA * prevValue * mask).sub(values.get(i)));
            advantages.set(index, deltas.get(i).add(CommonParameter.GAMMA * CommonParameter.GAE_LAMBDA * prevAdvantage * mask));

            prevValue = values.getFloat(i);
            prevAdvantage = advantages.getFloat(i);
        }

        NDArray expected_returns = values.add(advantages);
        NDArray advantagesMean = advantages.mean();
        NDArray advantagesStd = advantages.sub(advantagesMean).pow(2).sum().div(advantages.size() - 1).sqrt();
        advantages = advantages.sub(advantagesMean).div(advantagesStd);

        return new NDList(expected_returns, advantages);
    }

    private NDArray getSample(NDManager subManager, NDArray array, int[] index) {
        Shape shape = Shape.update(array.getShape(), 0, index.length);
        NDArray sample = subManager.zeros(shape, array.getDataType());
        for (int i = 0; i < index.length; i++) {
            sample.set(new NDIndex(i), array.get(index[i]));
        }
        return sample;
    }
}
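
The agent above relies on two small project-local utilities, utils.ActionSampler and utils.Helper, which are not shown here. A hedged sketch of what they might look like as two separate files (the bodies are assumptions; the project's actual implementations may differ):

// ActionSampler.java: samples an action index from a categorical distribution.
package utils;

import ai.djl.ndarray.NDArray;

import java.util.Random;

public final class ActionSampler {

    private ActionSampler() {}

    public static int sampleMultinomial(NDArray probs, Random random) {
        float[] p = probs.toFloatArray();
        float r = random.nextFloat();
        float cumulative = 0f;
        for (int i = 0; i < p.length; i++) {
            cumulative += p[i];
            if (r < cumulative) {
                return i;
            }
        }
        return p.length - 1; // safety net in case the probabilities do not sum exactly to 1
    }
}

// Helper.java: for each row, picks the probability the policy assigned to the action actually taken.
package utils;

import ai.djl.ndarray.NDArray;

public final class Helper {

    private Helper() {}

    public static NDArray gather(NDArray probs, int[] actions) {
        float[] picked = new float[actions.length];
        for (int i = 0; i < actions.length; i++) {
            picked[i] = probs.get(i).getFloat(actions[i]);
        }
        return probs.getManager().create(picked);
    }
}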

Define the networks and the Env, then start training:

    public static TrainingResult runExample(String[] args) throws IOException {
        Arguments arguments = new Arguments().parseArgs(args);
        if (arguments == null) {
            return null;
        }

        Engine.getInstance().setRandomSeed(0);
        int epoch = 500;
        int batchSize = 64;
        int replayBufferSize = 2048;
        int gamesPerEpoch = 128;
        // Validation is deterministic, thus one game is enough
        int validationGamesPerEpoch = 1;
        float rewardDiscount = 0.9f;

        if (arguments.getLimit() != Long.MAX_VALUE) {
            gamesPerEpoch = Math.toIntExact(arguments.getLimit());
        }

        Random random = new Random(0);
        NDManager mainManager = NDManager.newBaseManager();
        CartPole env = new CartPole(mainManager, random, batchSize, replayBufferSize);
        int stateSpaceDim = (int) env.getObservation().singletonOrThrow().getShape().get(0);
        int actionSpaceDim = env.getActionSpace().size();
        Model policyModel = Model.newInstance("discrete_policy_model");
        BaseModelBlock policyNet = new DiscretePolicyModelBlock(actionSpaceDim, PPOParameter.POLICY_MODEL_HIDDEN_SIZE);
        policyModel.setBlock(policyNet);
        Model valueModel = Model.newInstance("critic_value_model");
        BaseModelBlock valueNet = new CriticValueModelBlock(PPOParameter.CRITIC_MODEL_HIDDEN_SIZE);
        valueModel.setBlock(valueNet);

        DefaultTrainingConfig policyConfig = setupTrainingConfig();
        DefaultTrainingConfig valueConfig = setupTrainingConfig();
        Trainer policyTrainer = policyModel.newTrainer(policyConfig);
        Trainer valueTrainer = valueModel.newTrainer(valueConfig);

        policyTrainer.initialize(new Shape(batchSize, 4));
        valueTrainer.initialize(new Shape(batchSize, 4));
        policyTrainer.notifyListeners(listener -> listener.onTrainingBegin(policyTrainer));
        valueTrainer.notifyListeners(listener -> listener.onTrainingBegin(valueTrainer));

        PPOAgent agent = new PPOAgent(random, policyTrainer, valueTrainer, rewardDiscount);
        for (int i = 0; i < epoch; i++) {

            int episode = 0;
            int size = 0;
            while (size < replayBufferSize) {
                episode++;
                float result = env.runEnvironment(agent, true);
                size += (int) result;
                System.out.println("[" + episode + "]train:" + result);
            }

            for (int j = 0; j < gamesPerEpoch; j++) {
                RlEnv.Step[] batchSteps = env.getBatch();
                agent.trainBatch(batchSteps);
                policyTrainer.step();
                valueTrainer.step();
//                System.out.println("train[" + i + "-" + j + "]:" + result);
            }

            for (int j = 0; j < validationGamesPerEpoch; j++) {
                float result = env.runEnvironment(agent, false);
                System.out.println("test:" + result);
            }
        }

        policyTrainer.notifyListeners(listener -> listener.onTrainingEnd(policyTrainer));
        valueTrainer.notifyListeners(listener -> listener.onTrainingEnd(valueTrainer));

        return policyTrainer.getTrainingResult();
    }
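
The example references setupTrainingConfig(), DiscretePolicyModelBlock and CriticValueModelBlock, which are not reproduced here. A minimal sketch of equivalent pieces, assuming small fully connected networks with a softmax on the policy head (layer sizes, learning rate and class names are assumptions, not the project's exact values):

import ai.djl.ndarray.NDList;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.tracker.Tracker;

public final class PPONetworks {

    private PPONetworks() {}

    // Shared trainer configuration for both the policy trainer and the value trainer.
    public static DefaultTrainingConfig setupTrainingConfig() {
        return new DefaultTrainingConfig(Loss.l2Loss())
                .optOptimizer(Adam.builder().optLearningRateTracker(Tracker.fixed(1e-3f)).build())
                .addTrainingListeners(TrainingListener.Defaults.basic());
    }

    // Policy head: observation -> hidden layer -> action probabilities (softmax over actions).
    public static Block discretePolicyBlock(int actionDim, int hiddenSize) {
        return new SequentialBlock()
                .add(Linear.builder().setUnits(hiddenSize).build())
                .add(Activation::relu)
                .add(Linear.builder().setUnits(actionDim).build())
                .add(list -> new NDList(list.singletonOrThrow().softmax(1)));
    }

    // Value head: observation -> hidden layer -> scalar state value.
    public static Block criticValueBlock(int hiddenSize) {
        return new SequentialBlock()
                .add(Linear.builder().setUnits(hiddenSize).build())
                .add(Activation::relu)
                .add(Linear.builder().setUnits(1).build());
    }
}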

FlappyBird

Below is an example of training a DQN agent to play FlappyBird with DJL.
The first step is the game itself. The game code was taken from GitHub; the only part that needs changing is that after every frame the game returns its state (the bird's position, whether it collided, the full game frame, whether it passed a pipe), which is then used for reward shaping and as the neural network input:

    /**
     *   @desc : execute one step and return the game state
     *   @auth : tyf
     *   @date : 2023-10-20  10:32:30
    */
    public GameValue step(int action) {
        if (action == 1) {
            bird.birdFlap();
        }
        // Return the game state
        GameValue gameValue = stepFrame();
        if (this.withGraphics) {
            try {
                // Return the frame rendered after this step; it is later used as the neural network input
                gameValue.setImage(currentImg);
                Thread.sleep(FPS);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            // Render the frame
            repaint();
        }
        return gameValue;
    }

    /**
	*   @desc : quick manual test of the game
	*   @auth : tyf
	*   @date : 2023-10-20  10:30:26
	*/
	public static void checkGame(){
		FlappyBird flappyBird = new FlappyBird(true);
		while (true){
			// Randomly choose action 0 or 1
			int action = new Random().nextInt(10)%2==0?1:0;
			// Every step returns a game state;
			// the Env uses it to decide the reward and whether the game is over
			FlappyBird.GameValue v = flappyBird.step(action);
			// Each frame image serves as the neural network input
			BufferedImage img = v.getImage();
			// Restart when the game is over
			if(v.isDone()){
				flappyBird.restartGame();
			}

			System.out.println("当前状态:"+v);
		}

	}


Next, define the environment and the agent. The Env ties the game to the reinforcement learning framework: in general Env.step() calls the game's step(), then applies the reward shaping, checks whether the game is over, and so on:

package org.env;

import ai.djl.Model;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.modality.rl.ActionSpace;
import ai.djl.modality.rl.LruReplayBuffer;
import ai.djl.modality.rl.ReplayBuffer;
import ai.djl.modality.rl.agent.RlAgent;
import ai.djl.modality.rl.env.RlEnv;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import org.game.FlappyBird;

import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayDeque;
import java.util.Queue;


/**
 *   @desc : wraps the game and the training environment
 *   @auth : tyf
 *   @date : 2023-10-20  11:55:25
*/
public class GameEnv implements RlEnv {


    // The game instance
    private FlappyBird flappyBird;

    // Training environment
    private final NDManager manager; // NDArray manager
    private final LruReplayBuffer replayBuffer; // experience replay buffer
    private NDList currentObservation; // observation for the game's current frame
    private ActionSpace actionSpace; // action space

    // Observation cache: holds 4 consecutive frames; the 4 tensors are stacked into one tensor before being fed to the network
    private final Queue<NDArray> imgQueue = new ArrayDeque<>(4);

    // Counter for game/training progress
    public static int trainStep = 0;

    // The neural network
    public static Model model;

    /**
     *   @desc : create the game environment
     *   @auth : tyf
     *   @date : 2023-10-20  14:45:08
    */
    public GameEnv(FlappyBird flappyBird, NDManager manager, LruReplayBuffer replayBuffer,Model model) {

        // The game instance
        this.flappyBird = flappyBird;
        // NDArray manager
        this.manager = manager;
        // Experience replay buffer
        this.replayBuffer = replayBuffer;
        // Neural network
        this.model = model;

        // Initialize the action space; this game only has two actions
        actionSpace = new ActionSpace();
        actionSpace.add(new NDList(manager.create(new int[]{0,1}))); // DO_NOTHING: don't flap
        actionSpace.add(new NDList(manager.create(new int[]{1,0}))); // FLAP: flap

        // Initialize the current observation.
        // currentObservation is four frames stacked into one tensor.
        // At startup there are no 4 consecutive frames yet, so the same frame is converted and stacked 4 times.
        currentObservation = createObservation(flappyBird.getCurrentImg());

    }

    @Override
    public void reset() {
        // Restart the game
        flappyBird.restartGame();
    }

    /**
     *   @desc : return the game's current observation
     *   @auth : tyf
     *   @date : 2023-10-20  15:24:46
    */
    @Override
    public NDList getObservation() {
        return this.currentObservation;
    }

    /**
     *   @desc : return the action space defined by the game; the agent picks an action from it according to its exploration policy
     *   @auth : tyf
     *   @date : 2023-10-20  15:25:19
    */
    @Override
    public ActionSpace getActionSpace() {
        return this.actionSpace;
    }


    /**
     * @return
     * @desc : execute one step with the action chosen by the agent's exploration policy; this is where the game is driven
     * @auth : tyf
     * @date : 2023-10-20  15:27:12
     */
    @Override
    public Step step(NDList action, boolean b) {

        // First drive the game and get its output;
        // the Env action is converted into the game's action
        int act = action.singletonOrThrow().getInt(0);
        FlappyBird.GameValue gameOut = flappyBird.step(act);

        // Previous observation
        NDList preObservation = currentObservation;
        // New observation
        currentObservation = createObservation(gameOut.getImage());

        // Shape the reward based on the game output to speed up training
        float reward = 0.0f;

        // The bird's score: the number of pipes passed
        long score = gameOut.getScore();

        // Reward shaping:
        //   game over (hit the ceiling, the ground or a pipe): -1
        //   badly positioned (may hit the upcoming pipe, but not dead yet): 0.1
        //   well positioned (will not hit the upcoming pipe): 0.2
        //   passed a pipe (scored a point): 0.8
        //   currently inside a pipe gap: 1

        //        CRASH_MAYBE,     // may crash in the future
        //        CRASH_NO,        // will not crash in the future
        //        GET_SCORE,       // the bird passed a pipe
        //        GET_SCORE_MAYBE, // the bird is inside a pipe gap

        if(gameOut.isDone()){
            reward = -1.0f;
        }else{
            // may crash in the future
            if(gameOut.getType().equals(FlappyBird.TYPE.CRASH_MAYBE)){
                reward = 0.1f;
            }
            // will not crash in the future
            else if(gameOut.getType().equals(FlappyBird.TYPE.CRASH_NO)){
                reward = 0.2f;
            }
            // the bird passed a pipe
            else if(gameOut.getType().equals(FlappyBird.TYPE.GET_SCORE)){
                reward = 0.8f;
            }
            // the bird is inside a pipe gap
            else if(gameOut.getType().equals(FlappyBird.TYPE.GET_SCORE_MAYBE)){
                reward = 1f;
            }
        }

        // Whether the game is over
        boolean done = gameOut.isDone();

        // Wrap this transition as a Step and add it to the experience replay buffer
        Step s = new GameStep(manager.newSubManager(), actionSpace, preObservation, currentObservation, action, reward, done);
        replayBuffer.addStep(s);

        // Restart if the game is over
        if (done) {
            flappyBird.restartGame();
        }

        trainStep++;
        System.out.println("trainStep:"+trainStep+", action:"+act+", reward:"+reward+", done:"+done+", score:"+score);

        return s;
    }

    @Override
    public Step[] getBatch() {
        return replayBuffer.getBatch();
    }

    @Override
    public void close() {

    }

    /**
     * @desc : run the environment; this is the training entry point and the place where step() is called
     * @auth : tyf
     * @date : 2023-10-20  15:13:54
     */
    @Override
    public float runEnvironment(RlAgent agent, boolean training) {

        // Choose an action
        NDList action = agent.chooseAction(this, training);

        // Execute it
        this.step(action,training);

        // Checkpoint at a fixed interval
        if(trainStep%10000==0){
            // Update the neural network
            agent.trainBatch(getBatch());
            try {
                model.save(Paths.get("src/main/resources/model"), "dqn-" + trainStep);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        return 0;
    }

    /**
     *   @desc : convert each frame into an input tensor; every 4 consecutive frames are stacked into one network input
     *   @auth : tyf
     *   @date : 2023-10-20  09:59:40
     */
    public NDList createObservation(BufferedImage currentImg) {
        // Convert the frame into an 80x80 grayscale tensor
        NDArray observation = NDImageUtils.toTensor(
                NDImageUtils.resize(
                        ImageFactory.getInstance().fromImage(currentImg).toNDArray(NDManager.newBaseManager(), Image.Flag.GRAYSCALE),
                        80, 80));
        if (imgQueue.isEmpty()) {
            for (int i = 0; i < 4; i++) {
                imgQueue.offer(observation);
            }
            return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
        } else {
            imgQueue.remove();
            imgQueue.offer(observation);
            NDArray[] buf = new NDArray[4];
            int i = 0;
            for (NDArray nd : imgQueue) {
                buf[i++] = nd;
            }
            return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
        }
    }


    /**
     *   @desc : wraps a single step that goes into the experience replay buffer
     *   @auth : tyf
     *   @date : 2023-10-20  15:32:16
    */
    static final class GameStep implements Step{
        private final NDManager manager;
        private final ActionSpace actionSpace;
        private final NDList preObservation;
        private final NDList postObservation;
        private final NDList action;
        private final float reward;
        private final boolean done;
        public GameStep(NDManager manager, ActionSpace actionSpace,NDList preObservation, NDList postObservation, NDList action, float reward, boolean done) {
            this.manager = manager;
            this.actionSpace = actionSpace;
            this.preObservation = preObservation;
            this.postObservation = postObservation;
            this.action = action;
            this.reward = reward;
            this.done = done;
        }
        @Override
        public NDList getPreObservation() {
            preObservation.attach(manager);
            return preObservation;
        }
        @Override
        public NDList getPostObservation() {
            return postObservation;
        }
        @Override
        public NDList getAction() {
            return action;
        }
        @Override
        public ActionSpace getPostActionSpace() {
            return actionSpace;
        }
        @Override
        public NDArray getReward() {
            return manager.create(reward);
        }
        @Override
        public boolean isDone() {
            return done;
        }
        @Override
        public void close() {
            this.manager.close();
        }
    }
}

When creating the game environment you also need to set up the neural network, i.e. initialize the Trainer. After that it is just a matter of defining the total number of steps and checkpointing periodically. Inside the agent, trainBatch() pulls a batch of transitions from the replay buffer to update the network parameters. The overall training code looks like this:

	/**
	 *   @desc : training
	 *   @auth : tyf
	 *   @date : 2023-10-20  14:40:43
	*/
	public static void train() throws Exception{

		// First define the network structure
		Model model = Model.newInstance("QNetwork");
		// conv -> conv -> conv -> fc -> fc
		SequentialBlock block = new SequentialBlock()
				.add(Conv2d.builder().setKernelShape(new Shape(8, 8)).optStride(new Shape(4, 4)).optPadding(new Shape(3, 3)).setFilters(4).build())
				.add(Activation::relu)
				.add(Conv2d.builder().setKernelShape(new Shape(4, 4)).optStride(new Shape(2, 2)).setFilters(32).build())
				.add(Activation::relu)
				.add(Conv2d.builder().setKernelShape(new Shape(3, 3)).optStride(new Shape(1, 1)).setFilters(64).build())
				.add(Activation::relu)
				.add(Blocks.batchFlattenBlock())
				.add(Linear.builder().setUnits(512).build())
				.add(Activation::relu)
				.add(Linear.builder().setUnits(2).build());
		model.setBlock(block);

		// Training and update configuration
		DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss())
				.optOptimizer(Adam.builder().optLearningRateTracker(Tracker.fixed(1e-6f)).build())
				.addEvaluator(new Accuracy())
				.optInitializer(new NormalInitializer())
				.addTrainingListeners(TrainingListener.Defaults.basic());

		// Whether to start from pre-trained weights; if so, load them from src/main/resources/model/dqn-trained-nnnn.params
		boolean usePreTrained = false;
		if(usePreTrained){
			model.load(Paths.get("src/main/resources/model"), "dqn-trained");
		}

		// Create the trainer
		int batchSize = 32;
		Trainer trainer = model.newTrainer(config);
		trainer.initialize(new Shape(batchSize, 4, 80, 80)); // model input shape
		trainer.notifyListeners(listener -> listener.onTrainingBegin(trainer));

		// Maximum total number of steps
		float steps = 3000000;

		// Exploration tracker (exploration schedule):
		// explore mostly in the early phase, exploit mostly in the later phase
		Tracker exploreRate = new LinearTracker.Builder()
				.setBaseValue(0.01f) // initial exploration rate
				.optSlope(-(0.01f - 0.00001f) / steps) // how fast the exploration rate decays
				.optMinValue(0.00001f) // the exploration rate keeps decaying until it reaches this minimum
				.build();

		// NDArray manager
		NDManager manager = NDManager.newBaseManager();

		// The agent: gamma=0.9f (the larger gamma is, the more weight future rewards get).
		// GameAgent follows DJL's built-in QAgent, i.e. the default Q-learning implementation;
		// to implement double DQN (DDQN) you would override its trainBatch() method.
		RlAgent agent = new EpsilonGreedy(new GameAgent(trainer, 0.9f,manager), exploreRate);

		// Whether to render the game window
		boolean show = true;

		// Build the RL experience replay buffer
		int replayBufferSize = 50000;
		LruReplayBuffer replayBuffer = new LruReplayBuffer(batchSize, replayBufferSize);

		// Build the environment
		GameEnv env = new GameEnv(new FlappyBird(show),manager,replayBuffer,model);

		// Start training, passing in the agent
		for (int i = 0; i < steps; i++) {
			env.runEnvironment(agent,true);
		}


	}
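
GameAgent itself is not listed above. It plays the role of DJL's built-in QAgent: it holds the Trainer and the discount factor, picks the greedy action when asked (EpsilonGreedy wraps it and handles exploration), and performs a plain DQN update in trainBatch(). Below is a hedged sketch of what it could look like; the constructor matches how it is used above, but the body is an assumption rather than the project's exact code:

package org.env;

import ai.djl.engine.Engine;
import ai.djl.modality.rl.agent.RlAgent;
import ai.djl.modality.rl.env.RlEnv;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.GradientCollector;
import ai.djl.training.Trainer;
import ai.djl.translate.Batchifier;

// Hedged sketch of GameAgent: a plain DQN update, mirroring the idea of DJL's built-in QAgent.
public class GameAgent implements RlAgent {

    private final Trainer trainer;
    private final float rewardDiscount;
    private final NDManager manager;
    private final Batchifier batchifier = Batchifier.STACK;

    public GameAgent(Trainer trainer, float rewardDiscount, NDManager manager) {
        this.trainer = trainer;
        this.rewardDiscount = rewardDiscount;
        this.manager = manager;
    }

    @Override
    public NDList chooseAction(RlEnv env, boolean training) {
        // Greedy action: pick the action with the highest predicted Q value
        // (exploration is handled by the EpsilonGreedy wrapper).
        NDArray q = trainer.evaluate(batchifier.batchify(new NDList[]{env.getObservation()}))
                .singletonOrThrow();
        int best = Math.toIntExact(q.argMax().getLong());
        return env.getActionSpace().get(best);
    }

    @Override
    public void trainBatch(RlEnv.Step[] batchSteps) {
        NDList[] pre = new NDList[batchSteps.length];
        NDList[] post = new NDList[batchSteps.length];
        NDList[] actions = new NDList[batchSteps.length];
        float[] rewards = new float[batchSteps.length];
        float[] notDone = new float[batchSteps.length];
        for (int i = 0; i < batchSteps.length; i++) {
            pre[i] = batchSteps[i].getPreObservation();
            post[i] = batchSteps[i].getPostObservation();
            actions[i] = batchSteps[i].getAction();
            rewards[i] = batchSteps[i].getReward().getFloat();
            notDone[i] = batchSteps[i].isDone() ? 0f : 1f;
        }
        try (GradientCollector collector = Engine.getInstance().newGradientCollector()) {
            // Q(s, a) for the action actually taken (actions are one-hot, so mask and sum over the action axis).
            NDArray qAll = trainer.forward(batchifier.batchify(pre)).singletonOrThrow();
            NDArray mask = batchifier.batchify(actions).singletonOrThrow().toType(DataType.FLOAT32, false);
            NDArray q = qAll.mul(mask).sum(new int[] {1});
            // Bootstrapped target: r + gamma * max_a' Q(s', a'), zeroed out on terminal steps.
            NDArray nextMax = trainer.evaluate(batchifier.batchify(post)).singletonOrThrow().max(new int[] {1});
            NDArray target = manager.create(rewards)
                    .add(nextMax.mul(rewardDiscount).mul(manager.create(notDone)));
            NDArray loss = trainer.getLoss().evaluate(new NDList(target), new NDList(q));
            collector.backward(loss);
        }
        // Apply the accumulated gradients; the Env's runEnvironment() above does not call trainer.step() itself.
        trainer.step();
    }
}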

Then it just runs. It took roughly 200k steps before the agent learned to pass the first pipe.
