以下是一个使用 Java 实现的强化学习算法案例,它使用 Q-learning 算法来训练一个智能体来玩迷宫游戏:
```java
import java.util.Arrays;
import java.util.Random;
public class QLearning {
    // --- Q-learning hyperparameters ---
    private static final double ALPHA = 0.1;      // learning rate
    private static final double GAMMA = 0.9;      // discount factor
    private static final double EPSILON = 0.1;    // exploration probability (epsilon-greedy)
    private static final int NUM_EPISODES = 2000; // training episodes
    private static final int MAX_STEPS = 500;     // per-episode step cap — prevents infinite loops
    private static final int NUM_ACTIONS = 4;
    private static final String[] ACTIONS = {"up", "down", "left", "right"};
    // Row/column deltas, index-aligned with ACTIONS.
    private static final int[][] MOVES = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
    // Single seeded RNG: deterministic runs and no per-step allocation
    // (the original created a new Random on every exploratory step).
    private static final Random RANDOM = new Random(42);

    /**
     * Trains a tabular Q-learning agent on a small grid maze, then prints the
     * greedy trajectory from the start state to the goal state.
     */
    public static void main(String[] args) {
        // 1 = open cell, 0 = wall; the border is solid wall.
        // NOTE: in the original maze, column 6 was a wall in every row, which
        // split the grid into two disconnected halves and made the goal
        // unreachable. Rows 1 and 9 now open column 6 to connect them.
        int[][] maze = {
            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
            {0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0},
            {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0},
            {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0},
            {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0},
            {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0},
            {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0},
            {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0},
            {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0},
            {0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0},
            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
        };
        int[] startState = {1, 1};
        int[] goalState = {9, 9};

        // Q-table: one value per (row, col, action); Java zero-initializes it.
        double[][][] qTable = new double[maze.length][maze[0].length][NUM_ACTIONS];

        train(maze, qTable, startState, goalState);

        // Greedy rollout: follow argmax-Q actions and print the trajectory.
        int[] state = startState.clone();
        for (int step = 0; step < MAX_STEPS && !Arrays.equals(state, goalState); step++) {
            int action = argmax(qTable[state[0]][state[1]]);
            System.out.println("state: " + Arrays.toString(state) + ", action: " + ACTIONS[action]);
            int[] next = move(maze, state, action);
            if (Arrays.equals(next, state)) {
                break; // greedy action is blocked by a wall: policy failed here
            }
            state = next;
        }
        System.out.println("state: " + Arrays.toString(state));
    }

    /**
     * Runs epsilon-greedy Q-learning episodes, updating {@code qTable} in place.
     * Each episode starts at {@code start} and ends at {@code goal} or after
     * MAX_STEPS. Rewards: +100 for reaching the goal, -1 per step (blocked
     * moves also cost -1 and leave the agent in place) — the original rewarded
     * walls (+1) and gave the goal nothing, which trained the wrong behavior.
     */
    private static void train(int[][] maze, double[][][] qTable, int[] start, int[] goal) {
        for (int episode = 0; episode < NUM_EPISODES; episode++) {
            int[] state = start.clone();
            for (int step = 0; step < MAX_STEPS && !Arrays.equals(state, goal); step++) {
                // Epsilon-greedy action selection.
                int action = (RANDOM.nextDouble() < EPSILON)
                        ? RANDOM.nextInt(NUM_ACTIONS)
                        : argmax(qTable[state[0]][state[1]]);
                int[] next = move(maze, state, action);
                double reward = Arrays.equals(next, goal) ? 100.0 : -1.0;
                // Standard Q-learning update: Q += alpha * (TD target - Q).
                double target = reward + GAMMA * max(qTable[next[0]][next[1]]);
                qTable[state[0]][state[1]][action] +=
                        ALPHA * (target - qTable[state[0]][state[1]][action]);
                state = next;
            }
        }
    }

    /**
     * Applies {@code action} to {@code state} and returns the new position.
     * Moves that would leave the grid or enter a wall (maze value 0) return the
     * unchanged position — the original followed them unconditionally, letting
     * the agent walk through walls and eventually index out of bounds.
     */
    private static int[] move(int[][] maze, int[] state, int action) {
        int row = state[0] + MOVES[action][0];
        int col = state[1] + MOVES[action][1];
        boolean blocked = row < 0 || row >= maze.length
                || col < 0 || col >= maze[0].length
                || maze[row][col] == 0;
        return blocked ? state.clone() : new int[]{row, col};
    }

    /** Returns the index of the largest value in {@code array} (first on ties). */
    private static int argmax(double[] array) {
        int maxIndex = 0;
        for (int i = 1; i < array.length; i++) {
            if (array[i] > array[maxIndex]) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /** Returns the largest value in {@code array}. */
    private static double max(double[] array) {
        double maxValue = array[0];
        for (double value : array) {
            if (value > maxValue) {
                maxValue = value;
            }
        }
        return maxValue;
    }
}
```
该代码首先定义了一个网格迷宫环境(1 表示通路,0 表示墙壁),然后使用 Q-learning 算法训练智能体寻找从起点到终点的路径。训练结束后,它按贪心策略输出智能体在迷宫中每一步的状态与动作,最后打印到达的目标状态。