MATLAB中Reinforcement Learning Toolbox的使用教程1

该博客介绍了如何使用MATLAB的ReinforcementLearningToolbox进行强化学习,分别展示了Q-learning和SARSA算法在二维网格环境中的简单训练实现。通过创建环境、设置奖励和障碍,然后初始化并训练代理,最终展示训练结果和模拟过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

MATLAB提供了Reinforcement Learning Toolbox可以方便地建立二维基础网格环境、设置起点、目标、障碍,以及各种agent模型

1.Q-learning的训练简单实现

%% 布置环境硬件

GW = createGridWorld(6,6);

GW.CurrentState = '[6,1]';

GW.TerminalStates = '[2,5]';

GW.ObstacleStates = ["[2,3]";"[2,4]";"[3,5]";"[4,5]"];

%% 根据障碍设置可否行进

updateStateTranstionForObstacles(GW)

%% 设置reward

nS = numel(GW.States);

nA = numel(GW.Actions);

GW.R = -1*ones(nS,nS,nA);

GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;

%% 生成环境及初始位置

env = rlMDPEnv(GW);

plot(env)

env.ResetFcn = @() 6;

%% Q-learning训练参数初始化

qTable = rlTable(getObservationInfo(env),getActionInfo(env));

tableRep = rlRepresentation(qTable);

tableRep.Options.LearnRate = 1;

agentOpts = rlQAgentOptions;

agentOpts.EpsilonGreedyExploration.Epsilon = .04;

qAgent = rlQAgent(tableRep,agentOpts);

trainOpts = rlTrainingOptions;

trainOpts.MaxStepsPerEpisode = 50;

trainOpts.MaxEpisodes= 200;

trainOpts.StopTrainingCriteria = "AverageReward";

trainOpts.StopTrainingValue = 11;

trainOpts.ScoreAveragingWindowLength = 30;

%% 训练

rng(0)

trainingStats = train(qAgent,env,trainOpts);

%% 结果展示

plot(env)

env.Model.Viewer.ShowTrace = true;

env.Model.Viewer.clearTrace;

sim(qAgent,env)

2.SARSA的训练简单实现

%% 布置环境硬件

GW = createGridWorld(6,6);

GW.CurrentState = '[6,1]';

GW.TerminalStates = '[2,5]';

GW.ObstacleStates = ["[2,3]";"[2,4]";"[3,5]";"[4,5]"];

%% 设置可否行进

updateStateTranstionForObstacles(GW)

%% 设置reward

nS = numel(GW.States);

nA = numel(GW.Actions);

GW.R = -1*ones(nS,nS,nA);

GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;

%% 生成环境及初始位置

env = rlMDPEnv(GW);

plot(env)

env.ResetFcn = @() 6;

%% %% SARSA参数初始化

rng(0)

qTable = rlTable(getObservationInfo(env),getActionInfo(env));

tableRep = rlRepresentation(qTable);

tableRep.Options.LearnRate = 1;

agentOpts = rlSARSAAgentOptions;

agentOpts.EpsilonGreedyExploration.Epsilon = 0.04;

sarsaAgent = rlSARSAAgent(tableRep,agentOpts);

trainOpts = rlTrainingOptions;

trainOpts.MaxStepsPerEpisode = 50;

trainOpts.MaxEpisodes= 200;

trainOpts.StopTrainingCriteria = "AverageReward";

trainOpts.StopTrainingValue = 11;

trainOpts.ScoreAveragingWindowLength = 30;

%% 训练

trainingStats = train(sarsaAgent,env,trainOpts);

%% 结果展示

plot(env)

env.Model.Viewer.ShowTrace = true;

env.Model.Viewer.clearTrace;

sim(sarsaAgent,env)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MC数据局

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值