Reinforcement Learning with the Tianshou Platform

This article describes how the author used tianshou, the open-source reinforcement learning library from Tsinghua University, to quickly reproduce the DQN algorithm. It compares tianshou with the parl library in terms of documentation and performance, highlighting tianshou's advantages for rapid reproduction. Topics covered include parameter configuration, replay buffer management, TensorBoard integration, and a discussion of object-oriented design, demonstrated on the CartPole-v0 environment.

Using DQN with the tianshou reinforcement learning library

tianshou is a reinforcement learning library open-sourced by students at Tsinghua University. I needed reinforcement learning for some competitions, but between the time pressure and never having practiced quickly reproducing RL code, I did not score well; so here I try using the library for rapid reproduction.

I also tried parl and other libraries before; parl's documentation seems weaker than tianshou's, and as a beginner I am in no position to judge their relative performance. tianshou's official documentation has not been updated in a while and some of its code no longer runs, so I studied the code examples from the latest tianshou GitHub repository; my notes are recorded in the comments below.

import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer

def get_args():
    '''
    max_epoch: maximum number of training epochs; training may stop earlier
        if the stop_fn condition is satisfied
    step_per_epoch: number of policy-network updates per epoch
    collect_per_step: number of environment frames to collect before each
        update (with the parameters below, the network is updated once
        every 10 collected frames)
    episode_per_test: number of rollouts used for each evaluation
    batch_size: number of samples processed per policy update
    train_fn: called before training in each epoch; receives the current
        epoch and the total number of environment steps taken so far
        (the code below sets epsilon to 0.1 before each training phase)
    test_fn: called before testing in each epoch; same inputs as train_fn
        (the code below sets epsilon to 0.05 before each test phase)
    stop_fn: stopping condition; receives the current average undiscounted
        return and returns whether training should stop
    writer: tianshou supports TensorBoard; it is initialized further below
    :return: parsed arguments
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')  # environment name
    parser.add_argument('--seed', type=int, default=1626)  # random seed
    parser.add_argument('--eps-test', type=float, default=0.05)  # exploration rate for testing
    parser.add_argument('--eps-train', type=float, default=0.1)  # exploration rate for training
    parser.add_argument('--buffer-size', type=int, default=20000)  # replay buffer size
    parser.add_argument('--lr', type=float, default=1e-3)  # learning rate
    parser.add_argument('--gamma', type=float, default=0.9)  # discount factor
    parser.add_argument('--n-step', type=int, default=3)  # number of steps to look ahead
    parser.add_argument('--target-update-freq', type=int, default=320)
    # target network update frequency: update once every freq steps; 0 disables the target network
    parser.add_argument('--epoch', type=int, default=10)  # number of epochs
    parser.add_argument('--step-per-epoch', type=int, default=1000)  # policy updates per epoch
    parser.add_argument('--collect-per-step', type=int, default=10)  # frames collected before each update
    parser.add_argument('--batch-size', type=int, default=64)  # minibatch size for network training
    parser.add_argument('--hidden-sizes', type=int,
                        nargs='*', default=[128, 128, 128, 128])  # hidden layer sizes
    parser.add_argument('--training-num', type=int, default=8)  # number of training environments
    parser.add_argument('--test-num', type=int, default=100)  # number of test environments
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay',
                        action="store_true", default=False)  # prioritized replay
    parser.add_argument('--alpha', type=float, default=0.6)
    # replay buffer parameter: exponent applied to every sample's priority each round
    parser.add_argument('--beta', type=float, default=0.4)
    # replay buffer parameter: importance-sampling weight exponent;
    # the source contains a simplified form of the formula, see there for details
    parser.add_argument(
        '--save-buffer-name', type=str,
        default="./expert_DQN_CartPole-v0.pkl")
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args
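Note that the parser ends with `parse_known_args()[0]` rather than `parse_args()`: unknown flags are silently ignored instead of raising an error, which is handy when a test harness injects its own command-line options. A minimal standalone sketch (mirroring only two of the flags above) shows the difference:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626)

# parse_known_args returns (namespace, leftover); the unrecognized
# --foo flag ends up in `leftover` instead of causing an error
args, unknown = parser.parse_known_args(['--seed', '7', '--foo', 'bar'])
print(args.task, args.seed, unknown)  # CartPole-v0 7 ['--foo', 'bar']
```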

def test_dqn(args=get_args()):
    env = gym.make(args.task)  # build the env
    # state dimensions
    args.state_shape = env.observation_space.shape or env.observation_space.n
    # number of actions
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)

    # build the vectorized envs: DummyVectorEnv loops over the envs in a
    # for loop, SubprocVectorEnv uses subprocesses
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed everything for reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # Q_param = V_param = {"hidden_sizes": [128]}
    # model
    # build the neural network; Net is a class predefined by tianshou
    net = Net(args.state_shape, args.action_shape,
              hidden_sizes=args.hidden_sizes, device=args.device,
              # dueling=(Q_param, V_param),
              ).to(args.device)
    # optimizer
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # policy
    policy = DQNPolicy(
        net, optim, args.gamma, args.n_step,
        target_update_freq=args.target_update_freq)
    # replay buffer
    if args.prioritized_replay:
        buf = PrioritizedReplayBuffer(
            args.buffer_size, alpha=args.alpha, beta=args.beta)
    else:
        buf = ReplayBuffer(args.buffer_size)
    # collector: manages the interaction between policy and environments
    train_collector = Collector(policy, train_envs, buf)
    test_collector = Collector(policy, test_envs)
    # policy.set_eps(1)
    # the network trains on batch_size samples per update, so the buffer
    # must hold at least batch_size transitions before the first update;
    # collect one round up front so the replay buffer is not empty
    # (the collection statistics are discarded, only the buffer remains)
    train_collector.collect(n_step=args.batch_size, no_grad=False)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    # stopping condition: mean return exceeds the threshold
    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    # called before training in each epoch; receives the current epoch
    # and the total number of environment steps taken so far;
    # used here to anneal eps (the exploration rate) as training progresses
    def train_fn(epoch, env_step):
        # eps annealing, just a demo
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer: start learning (off-policy)
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer)

    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')

    # save buffer in pickle format, for imitation learning unittest
    buf = ReplayBuffer(args.buffer_size)
    collector = Collector(policy, test_envs, buf)
    # interact with the envs: after every step, check whether any env has
    # finished; if so, add that episode's length to the total step count;
    # n_step is the minimum step count, collection stops once it is exceeded
    collector.collect(n_step=args.buffer_size)

    pickle.dump(buf, open(args.save_buffer_name, "wb"))
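The `train_fn` above linearly anneals epsilon from `eps_train` down to a tenth of it between 10 000 and 50 000 environment steps, then holds it constant. Extracted as a standalone function (same constants as the code above, pure Python), the schedule is:

```python
def eps_schedule(env_step, eps_train=0.1):
    # flat at eps_train for the first 10 000 environment steps
    if env_step <= 10000:
        return eps_train
    # linear decay toward 0.1 * eps_train over the next 40 000 steps
    elif env_step <= 50000:
        return eps_train - (env_step - 10000) / 40000 * (0.9 * eps_train)
    # constant floor afterwards
    else:
        return 0.1 * eps_train

print(eps_schedule(0))      # 0.1
print(eps_schedule(30000))  # ≈ 0.055 (halfway through the decay)
print(eps_schedule(60000))  # ≈ 0.01
```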

def test_pdqn(args=get_args()):
    args.prioritized_replay = True
    args.gamma = .95
    args.seed = 1
    test_dqn(args)

if __name__ == '__main__':
    test_dqn(get_args())
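The `--n-step` flag passed to `DQNPolicy` makes the update bootstrap from a target n steps ahead instead of one. As a hedged, illustrative sketch of that target (plain Python; tianshou's internal implementation is vectorized and also handles episode boundaries):

```python
def n_step_target(rewards, gamma, bootstrap_value):
    """Compute sum_{i<n} gamma^i * r_i + gamma^n * bootstrap_value."""
    target = 0.0
    for i, r in enumerate(rewards):
        target += (gamma ** i) * r          # discounted observed rewards
    target += (gamma ** len(rewards)) * bootstrap_value  # bootstrapped tail
    return target

# with gamma=0.9, three rewards of 1, and a bootstrapped Q-value of 10:
# 1 + 0.9 + 0.81 + 0.729 * 10 ≈ 10.0
print(n_step_target([1.0, 1.0, 1.0], 0.9, 10.0))
```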


The code above runs in a script style. Keeping the parameters in `args` makes them easy to tweak, but once the logic itself needs changing, things get more complicated. With enough time available, an object-oriented design is worth considering.

Decompose the problem into a problem class and an algorithm class: the problem class focuses on describing the problem and supplies the problem's parameters, while the algorithm class focuses on describing the algorithm and supplies the algorithm's parameters.

Since some attributes of the algorithm class depend on attributes of the problem class, the obvious approach is to pass an instance of the problem class straight into the algorithm class. Here I try a different route instead: a third class that coordinates the other two, handling both the calls that solve the problem and the exchange of parameters between them. I cannot say which design is better; this is just an experiment of mine.

import pprint

import gym
import numpy as np
import torch

from tianshou.policy import DQNPolicy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer


class question:
    """Problem class: describes the environment and supplies its parameters."""

    def __init__(self, gamename='CartPole-v0'):
        self.gamename = gamename
        env = gym.make(gamename)
        self.env = env
        # dimension of the state space
        self.state_shape = env.observation_space.shape or env.observation_space.n
        # number of actions
        self.action_shape = env.action_space.shape or env.action_space.n


class program:
    """Algorithm class: describes the algorithm and supplies its parameters."""

    def __init__(self, eps_train, eps_test, epoch, hidden_sizes,
                 buffer_size, gamma=0.9, n_step=3,
                 device='cpu', lr=1e-3, target_update_freq=320,
                 training_num=1, test_num=1, batch_size=64,
                 step_per_epoch=1, collect_per_step=10):
        self.net = None     # network
        self.optim = None   # optimizer
        self.policy = None  # policy
        self.eps_train = eps_train
        self.eps_test = eps_test
        self.buf = ReplayBuffer(buffer_size)  # replay buffer
        self.train_collector = None
        self.test_collector = None
        self.reward_threshold = None
        self.epoch = epoch
        self.hidden_sizes = hidden_sizes
        self.device = device
        self.lr = lr
        self.step_per_epoch = step_per_epoch
        self.collect_per_step = collect_per_step
        self.training_num = training_num
        self.test_num = test_num
        self.batch_size = batch_size
        self.ready = False
        self.gamma = gamma
        self.n_step = n_step
        self.target_update_freq = target_update_freq

    # stop condition: average undiscounted return reaches the threshold
    def stop_fn(self, reward_threshold):
        return lambda mean_rewards: mean_rewards >= reward_threshold

    # called before training in each epoch, receiving the current epoch number
    # and the total number of env steps taken so far; here it anneals eps
    # (the exploration ratio) according to the step count
    def train_fn(self, epoch, env_step):
        # eps annealing, just a demo
        if env_step <= 10000:
            self.policy.set_eps(self.eps_train)
        elif env_step <= 50000:
            eps = self.eps_train - (env_step - 10000) / \
                40000 * (0.9 * self.eps_train)
            self.policy.set_eps(eps)
        else:
            self.policy.set_eps(0.1 * self.eps_train)

    def test_fn(self, epoch, env_step):
        self.policy.set_eps(self.eps_test)

    def solve(self):
        if self.ready:
            return offpolicy_trainer(
                self.policy, self.train_collector, self.test_collector, self.epoch,
                self.step_per_epoch, self.collect_per_step, self.test_num,
                self.batch_size, train_fn=self.train_fn, test_fn=self.test_fn,
                stop_fn=self.stop_fn(self.reward_threshold))
        else:
            raise Exception(
                'unknown error, maybe you should initialize via the resolve class first')


class resolve:
    """Coordinator class: wires the problem class and the algorithm class together."""

    def __init__(self, que: question, prg: program, seed=None):
        self.que = que
        self.prg = prg
        self.prg.train_envs = DummyVectorEnv(
            [lambda: gym.make(self.que.gamename) for _ in range(self.prg.training_num)])
        self.prg.test_envs = DummyVectorEnv(
            [lambda: gym.make(self.que.gamename) for _ in range(self.prg.test_num)])
        # For reproducibility, the seed must be set after the envs are built
        # but before the net and other components are created, because Net
        # consumes NumPy random numbers during initialization.
        self.set_seed(seed)
        self.prg.net = Net(self.que.state_shape, self.que.action_shape,
                           hidden_sizes=self.prg.hidden_sizes,
                           device=self.prg.device).to(self.prg.device)  # network
        self.prg.optim = torch.optim.Adam(
            self.prg.net.parameters(), lr=self.prg.lr)  # optimizer
        self.prg.policy = DQNPolicy(
            self.prg.net, self.prg.optim, self.prg.gamma, self.prg.n_step,
            target_update_freq=self.prg.target_update_freq)  # policy

        self.prg.train_collector = Collector(
            self.prg.policy, self.prg.train_envs, self.prg.buf)
        self.prg.test_collector = Collector(self.prg.policy, self.prg.test_envs)

        self.prg.reward_threshold = self.que.env.spec.reward_threshold
        self.prg.ready = True

    def set_seed(self, seed):
        np.random.seed(seed)
        torch.manual_seed(seed)
        self.prg.train_envs.seed(seed)
        self.prg.test_envs.seed(seed)

    def solve(self):
        return self.prg.solve()


def main():
    reslv = resolve(question(),
                    program(0.1, 0.05, 10,
                            [128, 128, 128, 128], 20000,
                            training_num=8, test_num=100,
                            step_per_epoch=1000),
                    seed=1626)
    result = reslv.solve()
    pprint.pprint(result)


if __name__ == '__main__':
    main()
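The eps-annealing schedule inside `train_fn` can be pulled out and checked on its own. This sketch simply restates the arithmetic from the code above, with `eps_train=0.1` as in `main`:

```python
def eps_schedule(env_step, eps_train=0.1):
    """Linear eps annealing: hold eps_train for the first 10k steps, decay
    linearly to 0.1 * eps_train between steps 10k and 50k, then hold."""
    if env_step <= 10000:
        return eps_train
    elif env_step <= 50000:
        return eps_train - (env_step - 10000) / 40000 * (0.9 * eps_train)
    else:
        return 0.1 * eps_train


assert eps_schedule(0) == 0.1                       # flat at the start
assert abs(eps_schedule(30000) - 0.055) < 1e-9      # halfway through the decay
assert abs(eps_schedule(100000) - 0.01) < 1e-9      # floor at 0.1 * eps_train
```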


The logic here is that the policy object describes the strategy, including the exploration strategy; DQN uses an epsilon-greedy policy, and its mask parameter can disable particular actions. The exploration strategy is implemented inside the policy's forward method, which returns both the predicted values and the exploratory action.
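A rough sketch of that idea, epsilon-greedy action selection restricted by a mask, might look like the following. This is an illustration only, not tianshou's actual `DQNPolicy.forward`:

```python
import numpy as np

rng = np.random.default_rng(0)


def masked_eps_greedy(q_values, mask, eps):
    """Pick an action epsilon-greedily among the actions where mask is True."""
    q = np.where(mask, q_values, -np.inf)  # masked actions can never win argmax
    if rng.random() < eps:
        # explore: sample uniformly among the legal actions
        return int(rng.choice(np.flatnonzero(mask)))
    # exploit: best legal action
    return int(np.argmax(q))


q = np.array([1.0, 5.0, 3.0])
mask = np.array([True, False, True])              # action 1 is illegal
assert masked_eps_greedy(q, mask, eps=0.0) == 2   # greedy picks best legal action
assert masked_eps_greedy(q, mask, eps=1.0) in (0, 2)  # exploration stays legal
```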

For reproducibility, the seed must be set after the environments are constructed but before the net and other components are created, because Net consumes NumPy random numbers during initialization; only seeding beforehand guarantees a reproducible run. This point took me a long time to track down...
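A minimal illustration of why the order matters (`build_net` here is a hypothetical stand-in for Net, which draws NumPy random numbers at construction time):

```python
import numpy as np


def build_net():
    """Stand-in for tianshou's Net: consumes NumPy randomness when built."""
    return np.random.rand(4)


# Seed BEFORE building the net: both runs draw identical random weights.
np.random.seed(1626)
run_a = build_net()
np.random.seed(1626)
run_b = build_net()
assert np.allclose(run_a, run_b)  # reproducible

# Build the net BEFORE seeding: the weights come from whatever RNG state
# happens to be current, so seeding afterwards cannot make them reproducible.
late = build_net()   # RNG state here is the post-seed state advanced by run_b
np.random.seed(1626)  # too late to affect `late`
```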

As for the update strategy, I have not fully understood it yet. According to the benchmarks in the documentation, tianshou is faster than the other reinforcement learning libraries it compares against, but its feature set is not yet complete: multi-agent and some other algorithms are missing, so I may have to consider switching to ray. As a distributed framework, ray reminds me of my painful days with Spark and MLlib.
