译自强化学习工具包文档,加入自己的见解
用Q-Learning和Sarsa解决网格世界环境。
配置和规则如下:
- 网格世界是5x5的有边界的格子,有4个可能的动作(North = 1, South = 2, East = 3, West = 4);
- 代理初始位置位于[2, 1]处的格子;
- 代理到达终点蓝色格子将获得奖励+10;
- 可以从[2, 4]跳到[4, 4],这样会获得奖励+5;
- 代理会被黑色障碍格子阻挡;
- 所有其它的动作将会使奖励-1。
创建网格世界环境
创建基本的网格世界环境:
env = rlPredefinedEnv("BasicGridWorld"); % 预制的环境
创建复位函数,这里位置[1, 1]处状态编号为1,沿着列增加,故初始位置状态编号为2:
env.ResetFcn = @() 2;
固定随机数生成器种子以方便复现:
rng(0)
创建Q-Learning代理
首先根据环境的观测空间和动作空间创建Q表:
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
qTable = rlTable(obsInfo, actInfo);
用rlQValueFunction逼近器对象创建Q值函数:
qFcnAppx = rlQValueFunction(qTable, obsInfo, actInfo);
创建Q-Learning代理:
qAgent = rlQAgent(qFcnAppx);
修改代理的部分超参数:
qAgent.AgentOptions.EpsilonGreedyExploration.Epsilon = .04; % ε-贪心探索的ε值
qAgent.AgentOptions.CriticOptimizerOptions.LearnRate = 0.01; % 学习率
训练Q-Learning代理
指定训练参数:
trainOpts = rlTrainingOptions;
% 每回合最多持续50步
trainOpts.MaxStepsPerEpisode = 50;
% 最多训练200回合
trainOpts.MaxEpisodes= 200;
% 当代理在30个连续回合内获得的平均累积奖励大于11时,停止训练。
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 11;
trainOpts.ScoreAveragingWindowLength = 30;
开始训练:
trainingStats = train(qAgent,env,trainOpts);
训练过程:
验证Q-Learning结果
可视化环境并配置参数使得能显示代理轨迹:
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
在训练环境运行对代理的模拟:
sim(qAgent,env)
从代理的轨迹可看出代理成功发现了从[2, 4]跳到[4, 4]的特殊奖励。
创建SARSA代理
流程一致,仅需将创建Q-Learning代理的语句替换成SARSA代理的,使用同样的Q值函数和同样的超参数配置:
sarsaAgent = rlSARSAAgent(qFcnAppx);