Reinforcement Learning Series 12: Training Deep Reinforcement Learning Models with Julia

1. Introduction

The ReinforcementLearning.jl package uses Flux as its deep learning framework. A minimal starting example:

using ReinforcementLearning
run(E`JuliaRL_BasicDQN_CartPole`)

Output:

This experiment uses three dense layers to approximate the Q value.
The testing environment is CartPoleEnv.


    Agent and statistic info will be saved to: `/home/runner/work/JuliaReinforcementLearning.github.io/JuliaReinforcementLearning.github.io/checkpoints/JuliaRL_BasicDQN_CartPole_2020_09_30_05_40_03`
    You can also view the tensorboard logs with
    `tensorboard --logdir /home/runner/work/JuliaReinforcementLearning.github.io/JuliaReinforcementLearning.github.io/checkpoints/JuliaRL_BasicDQN_CartPole_2020_09_30_05_40_03/tb_log`
    To load the agent and statistic info:
    ```
    agent = RLCore.load("/home/runner/work/JuliaReinforcementLearning.github.io/JuliaReinforcementLearning.github.io/checkpoints/JuliaRL_BasicDQN_CartPole_2020_09_30_05_40_03", Agent)
    BSON.@load joinpath("/home/runner/work/JuliaReinforcementLearning.github.io/JuliaReinforcementLearning.github.io/checkpoints/JuliaRL_BasicDQN_CartPole_2020_09_30_05_40_03", "stats.bson") total_reward_per_episode time_per_step
    ```

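Start TensorBoard as the message suggests. If it fails to start, go to the directory where the Python packages are installed (typically site-packages) and delete the following leftover folders: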
~-nsorboard-1.14.0.dist-info
~b_nightly-2.1.0a20191026.dist-info
~ensorboard-1.13.1.dist-info
~ensorboard-1.14.0.dist-info

We can also try other algorithms by substituting one of the following experiment names:
JuliaRL_BasicDQN_CartPole
JuliaRL_DQN_CartPole
JuliaRL_PrioritizedDQN_CartPole
JuliaRL_Rainbow_CartPole
JuliaRL_IQN_CartPole
JuliaRL_A2C_CartPole
JuliaRL_A2CGAE_CartPole
JuliaRL_PPO_CartPole

2. Components

An experiment is defined by four main components:

Agent
Environment
Hook
Stop Condition

Running run(E`JuliaRL_BasicDQN_CartPole`) is equivalent to running:

experiment     = E`JuliaRL_BasicDQN_CartPole`
agent          = experiment.agent
env            = experiment.env
stop_condition = experiment.stop_condition
hook           = experiment.hook
run(agent, env, stop_condition, hook)
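
To make the roles of the four components concrete, here is a minimal, library-free sketch of the kind of loop that such a run call drives. It is not the package's implementation: CoinEnv, reset_env!, step_env! and run_loop are hypothetical names invented for this illustration, and the policy, stop condition and hook are plain functions.

```julia
# Toy environment (hypothetical): every episode lasts exactly one step.
mutable struct CoinEnv
    done::Bool
    reward::Float64
end

reset_env!(env::CoinEnv) = (env.done = false; env.reward = 0.0; env)

function step_env!(env::CoinEnv, action::Int)
    env.reward = action == 1 ? 1.0 : 0.0   # action 1 pays 1, anything else pays 0
    env.done = true
    return env
end

# Hypothetical driver loop: the policy picks an action, the env advances,
# the hook observes, and the stop condition decides when to quit.
function run_loop(policy, env, should_stop, hook)
    t = 0
    reset_env!(env)
    while true
        step_env!(env, policy(env))
        t += 1
        hook(t, env)
        should_stop(t) && break
        env.done && reset_env!(env)
    end
    return t
end

rewards = Float64[]
steps = run_loop(
    env -> 1,                               # policy: always choose action 1
    CoinEnv(false, 0.0),                    # environment
    t -> t >= 5,                            # stop condition: stop after 5 steps
    (t, env) -> push!(rewards, env.reward), # hook: record every reward
)
```

Swapping any one of the four arguments changes a single aspect of the experiment (behaviour, task, duration, or logging) without touching the others, which is the point of keeping them separate.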

The full structure of the experiment is printed as follows:

ReinforcementLearningCore.Experiment
├─ agent => ReinforcementLearningCore.Agent
│  ├─ policy => ReinforcementLearningCore.QBasedPolicy
│  │  ├─ learner => ReinforcementLearningZoo.BasicDQNLearner
│  │  │  ├─ approximator => ReinforcementLearningCore.NeuralNetworkApproximator
│  │  │  │  ├─ model => Flux.Chain
│  │  │  │  │  └─ layers
│  │  │  │  │     ├─ 1
│  │  │  │  │     │  └─ Flux.Dense
│  │  │  │  │     │     ├─ W => 128×4 Array{Float32,2}
│  │  │  │  │     │     ├─ b => 128-element Array{Float32,1}
│  │  │  │  │     │     └─ σ => typeof(NNlib.relu)
│  │  │  │  │     ├─ 2
│  │  │  │  │     │  └─ Flux.Dense
│  │  │  │  │     │     ├─ W => 128×128 Array{Float32,2}
│  │  │  │  │     │     ├─ b => 128-element Array{Float32,1}
│  │  │  │  │     │     └─ σ => typeof(NNlib.relu)
│  │  │  │  │     └─ 3
│  │  │  │  │        └─ Flux.Dense
│  │  │  │  │           ├─ W => 2×128 Array{Float32,2}
│  │  │  │  │           ├─ b => 2-element Array{Float32,1}
│  │  │  │  │           └─ σ => typeof(identity)
│  │  │  │  └─ optimizer => Flux.Optimise.ADAM
│  │  │  │     ├─ eta => 0.001
│  │  │  │     ├─ beta
│  │  │  │     │  ├─ 1
│  │  │  │     │  │  └─ 0.9
│  │  │  │     │  └─ 2
│  │  │  │     │     └─ 0.999
│  │  │  │     └─ state => IdDict
│  │  │  │        ├─ ht => 32-element Array{Any,1}
│  │  │  │        ├─ count => 0
│  │  │  │        └─ ndel => 0
│  │  │  ├─ loss_func => typeof(ReinforcementLearningCore.huber_loss)
│  │  │  ├─ γ => 0.99
│  │  │  ├─ batch_size => 32
│  │  │  ├─ min_replay_history => 100
│  │  │  ├─ rng => Random.MersenneTwister
│  │  │  └─ loss => 0.0
│  │  └─ explorer => ReinforcementLearningCore.EpsilonGreedyExplorer
│  │     ├─ ϵ_stable => 0.01
│  │     ├─ ϵ_init => 1.0
│  │     ├─ warmup_steps => 0
│  │     ├─ decay_steps => 500
│  │     ├─ step => 1
│  │     ├─ rng => Random.MersenneTwister
│  │     └─ is_training => true
│  ├─ trajectory => ReinforcementLearningCore.CombinedTrajectory
│  │  ├─ reward => 0-element ReinforcementLearningCore.CircularArrayBuffer{Float32,1}
│  │  ├─ terminal => 0-element ReinforcementLearningCore.CircularArrayBuffer{Bool,1}
│  │  ├─ state => 4×0 view(::ReinforcementLearningCore.CircularArrayBuffer{Float32,2}, :, 1:0) with eltype Float32
│  │  ├─ next_state => 4×0 view(::ReinforcementLearningCore.CircularArrayBuffer{Float32,2}, :, 2:1) with eltype Float32
│  │  ├─ full_state => 4×0 view(::ReinforcementLearningCore.CircularArrayBuffer{Float32,2}, :, 1:0) with eltype Float32
│  │  ├─ action => 0-element view(::ReinforcementLearningCore.CircularArrayBuffer{Int64,1}, 1:0) with eltype Int64
│  │  ├─ next_action => 0-element view(::ReinforcementLearningCore.CircularArrayBuffer{Int64,1}, 2:1) with eltype Int64
│  │  └─ full_action => 0-element view(::ReinforcementLearningCore.CircularArrayBuffer{Int64,1}, 1:0) with eltype Int64
│  ├─ role => DEFAULT_PLAYER
│  └─ is_training => true
├─ env => ReinforcementLearningEnvironments.CartPoleEnv: ReinforcementLearningBase.SingleAgent(),ReinforcementLearningBase.Sequential(),ReinforcementLearningBase.PerfectInformation(),ReinforcementLearningBase.Deterministic(),ReinforcementLearningBase.StepReward(),ReinforcementLearningBase.GeneralSum(),ReinforcementLearningBase.MinimalActionSet(),ReinforcementLearningBase.Observation{Array}()
├─ stop_condition => ReinforcementLearningCore.StopAfterStep
│  ├─ step => 10000
│  ├─ cur => 1
│  └─ progress => ProgressMeter.Progress
├─ hook => ReinforcementLearningCore.ComposedHook
│  └─ hooks
│     ├─ 1
│     │  └─ ReinforcementLearningCore.TotalRewardPerEpisode
│     │     ├─ rewards => 0-element Array{Float64,1}
│     │     └─ reward => 0.0
│     ├─ 2
│     │  └─ ReinforcementLearningCore.TimePerStep
│     │     ├─ times => 0-element ReinforcementLearningCore.CircularArrayBuffer{Float64,1}
│     │     └─ t => 832228407951
│     ├─ 3
│     │  └─ ReinforcementLearningCore.DoEveryNStep
│     │     ├─ f => ReinforcementLearningZoo.var"#136#141"
│     │     ├─ n => 1
│     │     └─ t => 0
│     ├─ 4
│     │  └─ ReinforcementLearningCore.DoEveryNEpisode
│     │     ├─ f => ReinforcementLearningZoo.var"#138#143"
│     │     ├─ n => 1
│     │     └─ t => 0
│     └─ 5
│        └─ ReinforcementLearningCore.DoEveryNStep
│           ├─ f => ReinforcementLearningZoo.var"#140#145"
│           ├─ n => 10000
│           └─ t => 0
└─ description => "    This experiment uses three dense layers to approximate the Q value...

2.1 Agent

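An Agent takes an env and returns an action.
Structurally, it has two main parts: a Policy and a Trajectory. The Policy implements the mapping from env to action. The Trajectory records the interaction data, which is then used to update the policy at the appropriate time. Here is an example: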

using ReinforcementLearning
using Random
using Flux
rng = MersenneTwister(123) # a seeded random number generator
env = CartPoleEnv(;T = Float32, rng = rng) # T sets the element type of the recorded data
ns, na = length(get_state(env)), length(get_actions(env))
agent = Agent(
    policy = QBasedPolicy(
        learner = BasicDQNLearner(
            approximator = NeuralNetworkApproximator(
                model = Chain(
                    Dense(ns, 128, relu; initW = glorot_uniform(rng)),
                    Dense(128, 128, relu; initW = glorot_uniform(rng)),
                    Dense(128, na; initW = glorot_uniform(rng)),
                ) |> cpu,
                optimizer = ADAM(),
            ),
            batch_size = 32,
            min_replay_history = 100,
            loss_func = huber_loss,
            rng = rng,
        ),
        explorer = EpsilonGreedyExplorer(
            kind = :exp,
            ϵ_stable = 0.01,
            decay_steps = 500,
            rng = rng,
        ),
    ),
    trajectory = CircularCompactSARTSATrajectory(
        capacity = 1000,
        state_type = Float32,
        state_size = (ns,),
    ),
)

2.1.1 Policy

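The simplest policy is the random policy. Next comes QBasedPolicy, which needs a mapping function, the Q-learner, to estimate the value of each action, plus an explorer that, with some probability, tries actions that currently look worse.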
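
The interplay between the Q-values and the explorer can be sketched in plain Julia. This is an illustration under assumptions, not the library's code: epsilon_at and select_action are hypothetical helpers, and the exponential schedule below is one common choice that reuses the ϵ_init, ϵ_stable and decay_steps values shown in the experiment printout.

```julia
using Random

# Hypothetical helper: exploration rate that decays exponentially
# from ϵ_init towards ϵ_stable as training steps accumulate.
epsilon_at(step; ϵ_init = 1.0, ϵ_stable = 0.01, decay_steps = 500) =
    ϵ_stable + (ϵ_init - ϵ_stable) * exp(-step / decay_steps)

# Hypothetical helper: with probability ϵ pick a uniformly random action,
# otherwise pick the action with the largest estimated Q-value.
function select_action(rng::AbstractRNG, q_values::AbstractVector, ϵ::Real)
    return rand(rng) < ϵ ? rand(rng, 1:length(q_values)) : argmax(q_values)
end

rng = MersenneTwister(123)
q_values = [0.2, 0.7]                          # made-up Q-values for two actions
greedy = select_action(rng, q_values, 0.0)     # ϵ = 0 never explores
late_epsilon = epsilon_at(10_000)              # close to ϵ_stable after many steps
```

Early in training ϵ is near 1, so almost every action is random. As ϵ approaches ϵ_stable the policy becomes nearly greedy with respect to the learned Q-values.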

We can change a parameter of the Policy (here the discount factor γ) and plot the rewards to see the effect:

using Plots

experiment = E`JuliaRL_BasicDQN_CartPole`
experiment.agent.policy.learner.γ = 0.98
hook = TotalRewardPerEpisode()

run(experiment.agent, experiment.env, experiment.stop_condition, hook)
plot(hook.rewards)


2.1.2 Trajectory

t = CircularCompactSARTSATrajectory(;capacity=3)
push!(t; state=1, action=1, reward=0., terminal=false, next_state=2, next_action=2)
get_trace(t, :state)
empty!(t)

[Figure: illustration of the trajectory]
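
The "Circular" part of the trajectory name refers to a fixed-capacity buffer that drops its oldest entry once full. The sketch below shows that behaviour in plain Julia. RingBuffer and contents are hypothetical names for this illustration and are not the package's CircularArrayBuffer.

```julia
# Hypothetical type, for illustration only: a fixed-capacity buffer that
# overwrites its oldest entry once it is full.
mutable struct RingBuffer{T}
    data::Vector{T}
    capacity::Int
    head::Int     # index of the oldest element
    len::Int      # number of stored elements
end

RingBuffer{T}(capacity::Int) where {T} =
    RingBuffer{T}(Vector{T}(undef, capacity), capacity, 1, 0)

function Base.push!(b::RingBuffer, x)
    if b.len < b.capacity
        b.len += 1
        b.data[mod1(b.head + b.len - 1, b.capacity)] = x
    else
        b.data[b.head] = x                       # overwrite the oldest entry
        b.head = mod1(b.head + 1, b.capacity)    # the next-oldest becomes the head
    end
    return b
end

# Stored elements, ordered from oldest to newest.
contents(b::RingBuffer) =
    [b.data[mod1(b.head + i - 1, b.capacity)] for i in 1:b.len]

buf = RingBuffer{Int}(3)
for state in 1:5
    push!(buf, state)
end
stored = contents(buf)   # only the three most recent states remain
```

With capacity 3, pushing states 1 through 5 leaves only 3, 4 and 5 in the buffer, which is how a replay buffer keeps memory bounded during long training runs.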

2.2 Environment

An Env must define functions for getting the state, getting the reward, and checking for termination:

reset!(env)                  # reset env to the initial state
get_state(env)               # get the state from environment, usually it's a tensor
get_reward(env)              # get the reward since last interaction with environment
get_terminal(env)            # check if the game is terminated or not
env(rand(get_actions(env)))  # feed a random action to the environment

Below is the description printed for CartPoleEnv:

# ReinforcementLearningEnvironments.CartPoleEnv

## Traits

| Trait Type        |                                          Value |
|:----------------- | ----------------------------------------------:|
| NumAgentStyle     |        ReinforcementLearningBase.SingleAgent() |
| DynamicStyle      |         ReinforcementLearningBase.Sequential() |
| InformationStyle  | ReinforcementLearningBase.PerfectInformation() |
| ChanceStyle       |      ReinforcementLearningBase.Deterministic() |
| RewardStyle       |         ReinforcementLearningBase.StepReward() |
| UtilityStyle      |         ReinforcementLearningBase.GeneralSum() |
| ActionStyle       |   ReinforcementLearningBase.MinimalActionSet() |
| DefaultStateStyle | ReinforcementLearningBase.Observation{Array}() |

## Actions

ReinforcementLearningBase.DiscreteSpace{UnitRange{Int64}}(1:2)

## Players

  * `DEFAULT_PLAYER`

## Current Player

`DEFAULT_PLAYER`

## Is Environment Terminated?

No

2.3 Hook

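A hook is used to collect experiment data and to modify parameters of the agent and env. Two commonly used hooks are TotalRewardPerEpisode and TimePerStep; the example below uses the former: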

experiment = E`JuliaRL_BasicDQN_CartPole`
hook = TotalRewardPerEpisode()
run(experiment.agent, experiment.env, experiment.stop_condition, hook)
plot(hook.rewards)
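
What a per-episode reward hook does can be sketched without the library. EpisodeRewardHook below is a hypothetical stand-in; its two fields mirror the rewards and reward fields shown for TotalRewardPerEpisode in the experiment printout, and the call signature (reward, terminal) is an assumption made for this illustration.

```julia
# Hypothetical stand-in for a hook that accumulates the reward of the
# current episode and stores the total when the episode ends.
mutable struct EpisodeRewardHook
    rewards::Vector{Float64}   # one total per finished episode
    reward::Float64            # running total of the current episode
end

EpisodeRewardHook() = EpisodeRewardHook(Float64[], 0.0)

# Called once per step with the step reward and the terminal flag.
function (h::EpisodeRewardHook)(reward::Real, terminal::Bool)
    h.reward += reward
    if terminal
        push!(h.rewards, h.reward)
        h.reward = 0.0
    end
    return h
end

episode_hook = EpisodeRewardHook()
for (r, done) in [(1, false), (1, true), (2, true)]
    episode_hook(r, done)
end
totals = episode_hook.rewards   # first episode: 1 + 1, second episode: 2
```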

2.3.1 Custom hooks

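using Logging

# lg is assumed to be a logger created beforehand (for example a TensorBoard logger);
# it is not defined in this snippet.
DoEveryNStep() do t, agent, env
    with_logger(lg) do
        # record the learner's current loss under the "training" message
        @info "training" loss = agent.policy.learner.loss
    end
end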

2.4 Stop condition

Two commonly used stop conditions are StopAfterStep and StopAfterEpisode.
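
The difference between the two can be shown with a pair of hypothetical callable structs. StepLimit and EpisodeLimit are not the library's types, and their call signatures are assumptions for this sketch: one counts every step, the other counts only steps on which an episode terminates.

```julia
# Hypothetical stand-in: stops after a fixed number of steps.
mutable struct StepLimit
    limit::Int
    cur::Int
end
StepLimit(limit::Int) = StepLimit(limit, 0)

# Called once per step; returns true when the step budget is used up.
function (s::StepLimit)()
    s.cur += 1
    return s.cur >= s.limit
end

# Hypothetical stand-in: stops after a fixed number of finished episodes.
mutable struct EpisodeLimit
    limit::Int
    cur::Int
end
EpisodeLimit(limit::Int) = EpisodeLimit(limit, 0)

# Called once per step with the env's terminal flag; counts finished episodes.
function (s::EpisodeLimit)(terminal::Bool)
    terminal && (s.cur += 1)
    return s.cur >= s.limit
end

step_stop = StepLimit(3)
step_results = [step_stop() for _ in 1:3]

episode_stop = EpisodeLimit(2)
episode_results = [episode_stop(flag) for flag in (false, true, false, true)]
```

A step budget gives a predictable amount of computation, while an episode budget gives a predictable number of samples for per-episode statistics such as total reward.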

3. A custom example

The code to run has the following shape:

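# EVALUATION_FREQ, eval_stop_condition and eval_hook are assumed to be defined by the user
run(agent,
    env,
    stop_condition,
    DoEveryNStep(EVALUATION_FREQ) do t, agent, env
        Flux.testmode!(agent)    # switch to evaluation mode
        run(agent, env, eval_stop_condition, eval_hook)
        Flux.trainmode!(agent)   # switch back to training mode
    end)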

Here we take a lottery-buying environment as an example and evaluate it by Monte Carlo sampling with a random policy.

3.1 What needs to be defined

get_actions(env::YourEnv)
get_state(env::YourEnv)
get_reward(env::YourEnv)
get_terminal(env::YourEnv)
reset!(env::YourEnv)
(env::YourEnv)(action)

3.2 Defining the env

using ReinforcementLearning

mutable struct LotteryEnv <: AbstractEnv
    reward::Union{Nothing, Int}
end
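# Methods that must be implemented. RLBase is short for ReinforcementLearningBase.
# The framework wraps these methods, which is why get_reward simply returns env.reward.
# The framework must be loaded before the env can be defined.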
RLBase.get_actions(env::LotteryEnv) = (:PowerRich, :MegaHaul, nothing)
RLBase.get_reward(env::LotteryEnv) = env.reward
RLBase.get_state(env::LotteryEnv) = !isnothing(env.reward)
RLBase.get_terminal(env::LotteryEnv) = !isnothing(env.reward)
RLBase.reset!(env::LotteryEnv) = env.reward = nothing
# Make the env callable: it takes an action and sets the reward
function (env::LotteryEnv)(action)
    if action == :PowerRich
        env.reward = rand() < 0.01 ? 100_000_000 : -10
    elseif action == :MegaHaul
        env.reward = rand() < 0.05 ? 1_000_000 : -10
    else
        env.reward = 0
    end
end
env = LotteryEnv(nothing)

The REPL displays:

## Traits

| Trait Type        |                Value |
|:----------------- | --------------------:|
| NumAgentStyle     |        SingleAgent() |
| DynamicStyle      |         Sequential() |
| InformationStyle  | PerfectInformation() |
| ChanceStyle       |      Deterministic() |
| RewardStyle       |         StepReward() |
| UtilityStyle      |         GeneralSum() |
| ActionStyle       |   MinimalActionSet() |
| DefaultStateStyle | Observation{Array}() |

## Actions

(:PowerRich, :MegaHaul, nothing)

## Players

  * `DEFAULT_PLAYER`

## Current Player

`DEFAULT_PLAYER`

## Is Environment Terminated?

No

Next, let's do a test run:

hook = TotalRewardPerEpisode()
run(
           Agent(
               ;policy = RandomPolicy(env),
               trajectory = VectCompactSARTSATrajectory(
                   state_type=Bool,
                   action_type=Any,
                   reward_type=Int,
                   terminal_type=Bool,
               ),
           ),
           LotteryEnv(nothing),
           StopAfterEpisode(1_000),
           hook
       )
TotalRewardPerEpisode([-10.0, -10.0, -10.0, -10.0, -10.0, 0.0, 1.0e8, -10.0, -10.0, -10.0  …  -10.0, -10.0, 0.0, -10.0, 0.0, 0.0, -10.0, -10.0, -10.0, 0.0], 0.0)
println(sum(hook.rewards) / 1_000)
14993.58
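
As a sanity check on sampled averages like the one above, the expected reward of each action can be computed directly from the probabilities and payoffs in LotteryEnv. The helper expected_reward is a hypothetical name, and the last line assumes that the random policy picks each of the three actions with equal probability.

```julia
# Expected reward of one ticket: win `prize` with probability p_win,
# otherwise lose 10 (values copied from LotteryEnv).
expected_reward(p_win, prize; loss = -10) = p_win * prize + (1 - p_win) * loss

ev_power_rich = expected_reward(0.01, 100_000_000)
ev_mega_haul  = expected_reward(0.05, 1_000_000)
ev_nothing    = 0.0   # buying nothing always yields 0

# Assumption: a uniformly random policy picks each action with probability 1/3.
ev_random = (ev_power_rich + ev_mega_haul + ev_nothing) / 3

println((ev_power_rich, ev_mega_haul, ev_random))
```

Because the rare 100,000,000 jackpot dominates the variance, the average over a single run of 1,000 episodes can differ widely from this expectation.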