Q-Learning 算法
Q-Learning 算法论文复现(Java 版)
本文用 Java 语言复现了论文中的 Q-Learning 算法。
原文是用 Python 实现的,原文地址如下,方便跳转对照:
强化学习之Q-Learning(Python 实现)
我的 Java 版本:
/**
 * Trains a tabular Q-Learning agent on the six-room example, then runs an interactive test.
 *
 * <p>After training it prints the reward matrix R and the rescaled Q table. It then repeatedly
 * reads a start state from standard input and prints the greedy path from that state to the goal
 * state.
 *
 * <p>Action {@code a} moves the agent into state {@code a}, so {@code Q[a]} is the Q row of the
 * next state. The update uses an implicit learning rate of 1:
 * {@code Q[s][a] = R[s][a] + gamma * max(Q[a])}.
 *
 * <p>Interactive mode: enter a state number to print its path. Enter {@code n}, or close the
 * input, to quit.
 *
 * @param args unused
 * @throws Exception kept for signature compatibility; no checked exception is thrown here
 */
public static void main(String[] args) throws Exception {
    // Number of states and actions (action a means "move into state a").
    int state_num = 6, action_num = 6;
    // Discount factor applied to the value of the next state (not a learning rate).
    double gamma = 0.8;
    // Number of training episodes.
    int epochNumber = 200;
    // Goal state.
    int conditionStop = 5;
    // Reward matrix: -1 = transition not allowed, 0 = allowed, 100 = allowed and reaches the goal.
    int[][] R = Init(state_num, action_num);
    double[][] Q = new double[state_num][action_num];
    // A single generator for the whole run instead of a fresh Random per episode.
    Random random = new Random();

    for (int epoch = 0; epoch < epochNumber; epoch++) {
        int s = random.nextInt(state_num); // random start state for this episode
        while (true) {
            // Random action among those allowed from s (non-negative reward).
            int a = randomAction(R[s]);
            if (a < 0) {
                // No allowed action from s: the episode cannot continue.
                break;
            }
            // Best value obtainable from the next state a.
            double qMax = Arrays.stream(Q[a]).max().getAsDouble();
            Q[s][a] = R[s][a] + qMax * gamma;
            if (s == conditionStop) {
                // The goal check happens before moving. Every episode therefore ends with one
                // update taken from the goal state itself, which fills in the goal row of Q.
                break;
            }
            s = a;
        }
    }

    // Rescale Q so its largest entry becomes 100 (a percentage of the best value). Scaling by a
    // positive constant keeps the arg-max of every row, so the greedy policy is unchanged.
    double qTop = 0;
    for (double[] row : Q) {
        for (double v : row) {
            qTop = Math.max(qTop, v);
        }
    }
    if (qTop > 0) {
        for (double[] row : Q) {
            for (int j = 0; j < row.length; j++) {
                row[j] = row[j] * 100 / qTop;
            }
        }
    }

    // After training, repeatedly taking the allowed action with the largest Q value from any
    // state leads to the goal state (see printGreedyPath).
    System.out.println("");
    int[][] newQ = new int[state_num][action_num];
    for (int i = 0; i < state_num; i++) {
        for (int j = 0; j < action_num; j++) {
            newQ[i][j] = (int) Math.round(Q[i][j]);
        }
    }
    printMatrix(R);
    System.out.println();
    printMatrix(newQ);

    // Interactive test. A single Scanner is reused so buffered input is never lost. It is left
    // open on purpose: closing it would also close System.in.
    Scanner sc = new Scanner(System.in);
    while (true) {
        System.out.println("请输入起点:");
        if (!sc.hasNext()) {
            break; // end of input
        }
        if (!sc.hasNextInt()) {
            if (sc.next().equals("n")) {
                break;
            }
            System.out.println("请输入 0 到 " + (state_num - 1) + " 之间的整数,或输入 n 退出");
            continue;
        }
        int start = sc.nextInt();
        if (start < 0 || start >= state_num) {
            System.out.println("请输入 0 到 " + (state_num - 1) + " 之间的整数,或输入 n 退出");
            continue;
        }
        printGreedyPath(Q, R, start, conditionStop);
    }
}

/** Prints an int matrix row by row, each value followed by two tabs. */
private static void printMatrix(int[][] matrix) {
    for (int[] row : matrix) {
        for (int v : row) {
            System.out.print(v + "\t\t");
        }
        System.out.println();
    }
}

/** Prints the path from {@code start} to {@code goal} that follows the best allowed action. */
private static void printGreedyPath(double[][] q, int[][] r, int start, int goal) {
    int state = start;
    for (int steps = 0; state != goal; steps++) {
        // A path without repeats visits fewer than q.length non-goal states. Going past that
        // means the policy is cycling, and the loop would otherwise never end.
        if (steps >= q.length) {
            System.out.println("(未能到达终点:策略出现循环)");
            return;
        }
        System.out.print(state + " => ");
        int next = bestValidAction(q[state], r[state]);
        if (next < 0) {
            System.out.println("(未能到达终点:无可用动作)");
            return;
        }
        state = next;
    }
    System.out.println(goal);
}

/** Returns the allowed action (reward >= 0) with the largest Q value, first on ties, or -1 if none. */
private static int bestValidAction(double[] qRow, int[] rRow) {
    int best = -1;
    for (int i = 0; i < qRow.length; i++) {
        if (rRow[i] >= 0 && (best < 0 || qRow[i] > qRow[best])) {
            best = i;
        }
    }
    return best;
}
/**
 * Picks a uniformly random action among those with a non-negative reward.
 *
 * <p>A negative entry marks an action that is not allowed from the current state.
 *
 * @param nums one row of the reward matrix, indexed by action
 * @return the index of a randomly chosen allowed action, or {@code -1} if every entry is negative;
 *     callers must check for {@code -1} before using the result as an array index
 */
private static int randomAction(int[] nums) {
    // Collect allowed indices in a primitive buffer; the first `count` slots are meaningful.
    int[] candidates = new int[nums.length];
    int count = 0;
    for (int i = 0; i < nums.length; i++) {
        if (nums[i] >= 0) {
            candidates[count++] = i;
        }
    }
    if (count == 0) {
        // Every action is disallowed.
        return -1;
    }
    // Reuse the thread-local generator instead of allocating a new Random on every call.
    return candidates[java.util.concurrent.ThreadLocalRandom.current().nextInt(count)];
}
/**
 * Builds the reward matrix used for training.
 *
 * <p>Entry {@code [s][a]} is the immediate reward for taking action {@code a} in state {@code s}:
 * {@code -1} marks a disallowed transition and {@code 0} an allowed one. {@code 100} marks an
 * allowed transition into state 5; this covers {@code [1][5]}, {@code [4][5]} and {@code [5][5]}.
 *
 * @param state_num number of rows; the hard-coded transitions need at least 6
 * @param action_num number of columns; the hard-coded transitions need at least 6
 * @return a freshly allocated {@code state_num x action_num} reward matrix
 */
private static int[][] Init(int state_num, int action_num) {
    int[][] rewards = new int[state_num][action_num];
    for (int[] row : rewards) {
        Arrays.fill(row, -1);
    }
    // Each entry is {from, to, reward}, applied in this exact order.
    int[][] transitions = {
        {0, 4, 0}, {1, 3, 0}, {2, 3, 0}, {3, 1, 0}, {3, 2, 0},
        {3, 4, 0}, {4, 0, 0}, {4, 3, 0}, {5, 1, 0}, {5, 4, 0},
        {1, 5, 100}, {4, 5, 100}, {5, 5, 100},
    };
    for (int[] t : transitions) {
        rewards[t[0]][t[1]] = t[2];
    }
    return rewards;
}
}