Java Q-Learning Algorithm Implementation

Reproducing the paper's Q-Learning algorithm in Java

This post reimplements the Q-Learning algorithm from the paper in Java.

The original implementation is in Python; the link is included here for easy reference:
强化学习之Q-Learning(python实现) (Q-Learning for Reinforcement Learning, Python implementation)
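
For context, the update implemented below is the simplified tabular Q-Learning rule with the learning rate fixed at 1. In this room model, taking action a moves the agent directly to state a, so the successor state s' is simply a:

    Q(s, a) \leftarrow R(s, a) + \gamma \cdot \max_{a'} Q(s', a')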

My Java version:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.Scanner;

public class QLearning {

    public static void main(String[] args) throws Exception {
        //number of states and number of actions
        int state_num = 6, action_num = 6;
        //discount factor
        double gamma = 0.8;
        //number of training episodes
        int epochNumber = 200;
        //goal state; reaching it ends an episode
        int conditionStop = 5;

        //initialize the reward matrix
        int[][] R = Init(state_num, action_num);
        double[][] Q = new double[state_num][action_num];
        Random random = new Random();
        for (int epoch = 0; epoch < epochNumber; epoch++) { //loop over training episodes
//            System.out.println("Episode " + epoch + " begins");
            int s = random.nextInt(state_num); //pick a random starting state
//            System.out.println("Current state: " + s);
            boolean loop = true;
            while (loop) {
                //pick a random action whose reward is non-negative, i.e., a legal move
                int a = randomAction(R[s]);
//                System.out.println("In state " + s + ", chose action " + a);
                //taking action a leads to state a, so the best follow-up value
                //is the maximum of row a of the Q table
                double qMax = Arrays.stream(Q[a]).max().orElseThrow(() -> new NoSuchElementException("No value present"));
//                System.out.println("Max of Q-table row " + a + ": " + qMax);
                Q[s][a] = R[s][a] + qMax * gamma;
                if (s == conditionStop) {
                    //the current state is the goal state 5, so this
                    //episode (one update sequence) is complete
                    loop = false;
                } else {
                    s = a;
                }
//                System.out.println("Step done, state is now " + s);
            }
        }
        //scale Q to a 0-100 range: with gamma = 0.8 the largest Q value
        //converges to 100 / (1 - 0.8) = 500, so divide by 5
        for (int i = 0; i < state_num; i++) {
            for (int j = 0; j < action_num; j++) {
                Q[i][j] = Q[i][j] / 5;
            }
        }
        /**
         * The final state-transition diagram learned by Q-Learning labels each
         * arrow with the learned value of that action. From any state, greedily
         * choosing the action with the largest value leads out of the rooms.
         */
        System.out.println();
        //round the scaled Q values to integers for display
        int[][] newQ = new int[state_num][action_num];
        for (int i = 0; i < state_num; i++) {
            for (int j = 0; j < action_num; j++) {
                newQ[i][j] = (int) Math.round(Q[i][j]);
            }
        }
        //print the reward matrix R
        for (int i = 0; i < state_num; i++) {
            for (int j = 0; j < action_num; j++) {
                System.out.print(R[i][j] + "\t\t");
            }
            System.out.println();
        }
        System.out.println();
        //print the learned (scaled and rounded) Q table
        for (int i = 0; i < state_num; i++) {
            for (int j = 0; j < action_num; j++) {
                System.out.print(newQ[i][j] + "\t\t");
            }
            System.out.println();
        }
        //interactive test: follow the greedy policy from a user-chosen start state
        Scanner sc = new Scanner(System.in);
        while (true) {
            System.out.println("Enter a start state (or n to quit):");
            String input = sc.next();
            if (input.equals("n")) break;
            int start = Integer.parseInt(input);
            while (start != conditionStop) {
                System.out.print(start + " => ");
                double[] arrs = Q[start];
                //Double.MIN_VALUE is the smallest positive double, which would
                //break on an all-zero row; use NEGATIVE_INFINITY as the sentinel
                double max = Double.NEGATIVE_INFINITY;
                int next = -1;
                for (int i = 0; i < arrs.length; i++) {
                    if (arrs[i] > max) {
                        max = arrs[i];
                        next = i;
                    }
                }
                start = next;
            }
            System.out.println(conditionStop);
        }

    }

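    /**
     * Picks a uniformly random action whose reward entry is non-negative,
     * i.e., a move that is actually allowed from the given state.
     */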
    private static int randomAction(int[] nums) {
        ArrayList<Integer> list = new ArrayList<>();
        for (int i = 0; i < nums.length; i++) {
            if (nums[i] >= 0) list.add(i);
        }
        Random random = new Random();
        if (!list.isEmpty()) {
            int randomIndex = random.nextInt(list.size());
            return list.get(randomIndex);
        } else {
            //unreachable for this reward matrix: every state has at least
            //one non-negative entry
            return -1;
        }
    }

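    /**
     * Builds the reward matrix R for the room layout used in the referenced
     * tutorial: -1 = no connection, 0 = a door between two rooms, and
     * 100 = a door leading to the goal state 5 (outside).
     */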
    private static int[][] Init(int state_num, int action_num) {
        int[][] maze = new int[state_num][action_num];
        for (int i = 0; i < state_num; i++) {
            for (int j = 0; j < action_num; j++) {
                maze[i][j] = -1;
            }
        }
        maze[0][4] = 0;
        maze[1][3] = 0;
        maze[2][3] = 0;
        maze[3][1] = 0;
        maze[3][2] = 0;
        maze[3][4] = 0;
        maze[4][0] = 0;
        maze[4][3] = 0;
        maze[5][1] = 0;
        maze[5][4] = 0;
        maze[1][5] = 100;
        maze[4][5] = 100;
        maze[5][5] = 100;
        return maze;
    }

}
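
For a quick sanity check: with gamma = 0.8, the largest raw Q value converges to 100 / (1 - 0.8) = 500, which the division by 5 scales to 100. Given enough episodes, the rounded Q table printed by the program should settle at the well-known result for this room layout, and the greedy walk in the test phase reads straight off it (shown here for start state 2):

      0    0    0    0   80    0
      0    0    0   64    0  100
      0    0    0   64    0    0
      0   80   51    0   80    0
     64    0    0   64    0  100
      0   80    0    0   80  100

    2 => 3 => 1 => 5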