tianshou框架(一)

一 下载安装

下载tianhou框架

pip install tianshou

 下载源码https://github.com/thu-ml/tianshou/tree/v0.4.11

运行test文件下的test_dqn.py文件

显示如下

整个框架的如下

 

  • policy:主要是网络的构成及更新的方法
  • collector:收集数据
  • Trainer:训练

二 创建环境

创建环境有多种方法

使用gym创建

import gymnasium as gym

env = gym.make('CartPole-v1')

在tianshou 框架中,支持并行环境,提供一下四种创建环境的方式

  • DummyVectorEnv:最简单的创建方式,使用循环创建
  • SubprocVectorEnv:使用python多进程创建
  • ShmemVectorEnv:使用共享内存
  • RayVectorEnv:不知道啥玩意
train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(10)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(100)])

对于自定义环境,需要设置

def seed(self, seed):
    np.random.seed(seed)

三 网络创建

自定义网络时需要注意

  • Input:输入numpy.ndarray,torch.Tensor
  • Output:logits,元组,tensor或者在策略转发过程中其他一些有用的变量或结果。这取决于策略类如何处理网络输出,states

四 策略policy

policy = ts.policy.DQNPolicy(
    model=net,
    optim=optim,
    action_space=env.action_space,
    discount_factor=0.9,
    estimation_step=3,
    target_update_freq=320
)

五 Collector

train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)

六 Trainer

总共有三种训练方式

  • OnpolicyTrainer
  • OffpolicyTrainer
  • OfflineTrainer
result = ts.trainer.OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10, step_per_epoch=10000, step_per_collect=10,
    update_per_step=0.1, episode_per_test=100, batch_size=64,
    train_fn=lambda epoch, env_step: policy.set_eps(0.1),
    test_fn=lambda epoch, env_step: policy.set_eps(0.05),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold
).run()
print(f'Finished training! Use {result["duration"]}')

七 源码解读

下载源码文件,存在以下文件

  • docs:教程文档,说明书
  • example:一些实验的例子
  • test:使用的案列,大部分使用都在里面
  • tianshou:整个框架代码

进入test\discrete\test_dqn.py  里面是对dqn算法的测试

整个甲苯由三个函数组成

  • get_args:算法参数设置
  • test_dqn:主要部分
  • test_pdqn: 使用优先经验回放测试

首先是get_args:

def get_args():
    '''
    参数解析:
    --task: gym环境名称
    --reward-threshold: 奖励阈值
    --seed: 随机种子
    --eps-test: 测试时 epsilon
    --eps-train: 训练时 epsilon
    --buffer-size: 经验回放池大小
    --lr: 学习率
    --gamma: 折扣因子
    --n-step: n步更新
    --target-update-freq: 更新target网络的频率
    --epoch: 训练的回合数
    --step-per-epoch: 每个回合中,采集的轨迹数(及样本数)
    --step-per-collect: 每采集多少个轨迹更新一次网络
    --update-per-step: 每个step-per-collect过后网络更新的次数  更新次数 = step-per-collect * update-per-step
    --batch-size: 批量大小
    --hidden-sizes: 神经网络隐藏层的大小
    --training-num: 训练环境的数量
    --test-num: 测试环境的数量
    --logdir: tensorboard日志文件夹
    --render: 是否渲染环境
    --prioritized-replay: 是否使用优先经验回放
    --alpha: 优先经验回放的alpha
    --beta: 优先经验回放的beta
    --device: 训练设备
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v1')
    parser.add_argument('--reward-threshold', type=float, default=None)
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument(
        '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
    )
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    args = parser.parse_known_args()[0]
    return args

接着是test_dqn

def test_dqn(args=get_args()):
    # 创建环境,获取参数
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if args.reward_threshold is None:
        default_reward_threshold = {"CartPole-v0": 195}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # Q_param = V_param = {"hidden_sizes": [128]}
    # model
    # 创建网络
    net = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        # dueling=(Q_param, V_param),
    ).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # 策略
    policy = DQNPolicy(
        net,
        optim,
        args.gamma,
        args.n_step,
        target_update_freq=args.target_update_freq,
    )
    # buffer
    if args.prioritized_replay:
        buf = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
    # collector 数据收集
    train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        # 保存参数
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        # 停止训练条件
        return mean_rewards >= args.reward_threshold

    def train_fn(epoch, env_step):
        # eps annnealing, just a demo
        # 训练中eps设置
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer 创建
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        update_per_step=args.update_per_step,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
    )
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")

  • 6
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值