再一次被本科生吊打!清华本科生开源强化学习平台天授!

Java面试笔试面经、Java技术每天学习一点

Java面试

关注不迷路

开源最前线(ID:OpenSourceTop) 猿妹编译

项目地址:https://github.com/thu-ml/tianshou

深度强化学习(deep RL)近年来取得了令人瞩目的进步,就说2020年至今,就有多个深度学习框架相继开源,清华的Jittor、旷视的MegEngine、华为的Mindspore,国内首个开源深度学习框架PaddlePaddle等。近日,清华大学又新开源了强化学习平台天授。

这个项目的主要创建者是Jiayi Weng与Minghao Zhang,他们都是清华的本科生。Jiayi Weng今年6月份本科毕业,在此之前作为本科研究者与清华大学苏航、朱军等老师开展强化学习领域的相关研究。Minghao Zhang目前是清华大学软件学院的本科二年级学生,同时还修了数学专业。

是的,你没听错,就是本科生,是不是感觉又一次被吊打了呢?不过这两个人实力可都是杠杠的,就说Jiayi Weng,小升初的暑假就开始写代码。高二作为全国青少年信息学奥林匹克竞赛(NOI)选手进入省队。高中时期就开始钻研微积分、线性代数,大二上学期就加入了朱军教授领导的TSAIL实验室,大三暑假期间更是去到加拿大图灵奖获得者Bengio教授的实验室,深入开展了RL和NLP的研究。

而且这个项目取名为“天授”,这一词语源自《史记》,意为“取天所授而非学自人类”,刻画了强化学习通过与环境进行交互自主学习,而不需要像监督学习一样需要大量人类标注数据。

天授是什么?

天授(Tianshou)是纯基于 PyTorch 代码的强化学习框架,与目前现有基于TensorFlow 的强化学习库不同,天授的类继承并不复杂,API 也不是很繁琐。支持的 RL 算法包括:

  • Policy Gradient (PG)

  • Deep Q-Network (DQN)

  • Double DQN (DDQN) with n-step returns

  • Advantage Actor-Critic (A2C)

  • Deep Deterministic Policy Gradient (DDPG)

  • Proximal Policy Optimization (PPO)

  • Twin Delayed DDPG (TD3)

  • Soft Actor-Critic (SAC)

为什么要选择天授

速度快:天授是一个轻量级的高速强化学习平台。是在笔记本电脑(i7-8750H + GTX1060)上进行的测试。在CartPole-v0任务上,它仅需3秒就可以训练一个倒立摆(CartPole)。

上图为天授与各大知名 RL 开源平台在 CartPole 与 Pendulum 环境下的速度对比。所有代码均在配置为 i7-8750H + GTX1060 的同一台笔记本电脑上进行测试。

可复现性:天授有其单元测试。每一次单元测试除了基本功能的测试之外,还包括针对所有算法的完整训练过程,也就是说一旦有一个算法没办法 train 出来结果,单元测试不能通过。据我们所知,得益于天授快速的训练机制,天授是目前唯一一个采用这种标准进行单元测试的强化学习框架

模块化:天授将算法分解为四个部分:

  • init:策略初始化。

  • process_fn:处理函数,从回放缓存中处理数据。

  • call:根据观测值计算操作

  • learn:从给定数据包中学习

接口灵活:用户可以定制各种各样的 training 方法。提供示例,方便用户根据自己的需要进行二次开发

如何使用天授?

这是深度Q网络的一个示例。您还可以在test / discrete / test_dqn.py运行完整的脚本。

首先,导入一些相关的软件包:

import gym, torch, numpy as np, torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts

定义一些超参数:

task = 'CartPole-v0'
lr = 1e-3
gamma = 0.9
n_step = 3
eps_train, eps_test = 0.1, 0.05
epoch = 10
step_per_epoch = 1000
collect_per_step = 10
target_freq = 320
batch_size = 64
train_num, test_num = 8, 100
buffer_size = 20000
writer = SummaryWriter('log/dqn')  # tensorboard is also supported!

环境配置:

# you can also try with SubprocVectorEnv
train_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(test_num)])

建立网络:

class Net(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(*[
            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape))
        ])
    def forward(self, s, state=None, info={}):
        if not isinstance(s, torch.Tensor):
            s = torch.tensor(s, dtype=torch.float)
        batch = s.shape[0]
        logits = self.model(s.view(batch, -1))
        return logits, state

env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=lr)

设置策略和收集器:

policy = ts.policy.DQNPolicy(net, optim, gamma, n_step,
    use_target_network=True, target_update_freq=target_freq)
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
test_collector = ts.data.Collector(policy, test_envs)

训练:

result = ts.trainer.offpolicy_trainer(
    policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
    test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
    test_fn=lambda e: policy.set_eps(eps_test),
    stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task)
print(f'Finished training! Use {result["duration"]}')

保存/加载策略:

torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))

以35帧率观察模型表现:

collector = ts.data.Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
collector.close()

查看保存在tensorboard中的结果:

tensorboard --logdir log/dqn

你可以在Github和PyPI上找到天授的最新版本和其他资料。最后附上相关地址:

PyPI:https://pypi.org/project/tianshou/

Github天授主页:https://github.com/thu-ml/tianshou

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
好的,以下是一个简单的C语言五子棋小游戏程序: ```c #include <stdio.h> #include <stdlib.h> #define SIZE 15 char board[SIZE][SIZE]; void init_board() { int i, j; for (i = 0; i < SIZE; i++) { for (j = 0; j < SIZE; j++) { board[i][j] = '.'; } } } void print_board() { int i, j; printf("\n"); for (i = 0; i < SIZE; i++) { printf("%2d ", i); for (j = 0; j < SIZE; j++) { printf("%c ", board[i][j]); } printf("\n"); } printf(" "); for (i = 0; i < SIZE; i++) { printf("%d ", i); } printf("\n"); } int check_win(int x, int y) { char c = board[x][y]; int i, j, cnt; // check row cnt = 0; for (j = 0; j < SIZE; j++) { if (board[x][j] == c) { cnt++; if (cnt == 5) return 1; } else { cnt = 0; } } // check column cnt = 0; for (i = 0; i < SIZE; i++) { if (board[i][y] == c) { cnt++; if (cnt == 5) return 1; } else { cnt = 0; } } // check diagonal cnt = 0; for (i = x, j = y; i >= 0 && j >= 0; i--, j--) { if (board[i][j] == c) { cnt++; if (cnt == 5) return 1; } else { break; } } for (i = x+1, j = y+1; i < SIZE && j < SIZE; i++, j++) { if (board[i][j] == c) { cnt++; if (cnt == 5) return 1; } else { break; } } // check anti-diagonal cnt = 0; for (i = x, j = y; i >= 0 && j < SIZE; i--, j++) { if (board[i][j] == c) { cnt++; if (cnt == 5) return 1; } else { break; } } for (i = x+1, j = y-1; i < SIZE && j >= 0; i++, j--) { if (board[i][j] == c) { cnt++; if (cnt == 5) return 1; } else { break; } } return 0; } int main() { int x, y, turn = 0; init_board(); while (1) { print_board(); if (turn % 2 == 0) { printf("Player 1's turn (X)\n"); } else { printf("Player 2's turn (O)\n"); } printf("Enter the position (x, y): "); scanf("%d%d", &x, &y); if (x < 0 || x >= SIZE || y < 0 || y >= SIZE) { printf("Invalid position!\n"); continue; } if (board[x][y] != '.') { printf("This position is already occupied!\n"); continue; } board[x][y] = (turn % 2 == 0 ? 'X' : 'O'); if (check_win(x, y)) { print_board(); if (turn % 2 == 0) { printf("Player 1 (X) wins!\n"); } else { printf("Player 2 (O) wins!\n"); } break; } turn++; } return 0; } ``` 这个程序实现了一个基本的五子棋游戏,先让用户输入坐标,然后判断是否合法,如果合法则在棋盘上落子,并且判断是否胜利。如果胜利则输出获胜者,并结束游戏。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值