目录
PTAN简介
ptan是一个开源的RL封装包(Github地址),用于封装常用的RL代码,以便提高复用性,减少开发量。安装方法如下:
- 从pypi安装: pip install ptan
- 从github安装:pip install pip install git+https://github.com/Shmuma/ptan.git
ptan中常用的封装类包括:
- Agent:根据输入的Observation得到action
- Experience:得到agent和环境交互的信息
Agent
Agent包装类内部的工作流程如下:
ptan中已经封装好了一些常用的Agent类:DQNAgent,PolicyAgent。如果想自己自定义,可以继承基类BaseAgent
DNQAgent
主要用于dqn族的算法,Net输出的是离散型的actions。使用方法如下:
import torch.nn as nn
import torch.nn.functional as F
import torch
import ptan
#### 构建一个简单模型
class DQNModel(nn.Module):
def __init__(self, n_actions):
super().__init__()
self.n_actions = n_actions
def forward(self, x):
return torch.eye(x.size()[0], self.n_actions)
net = DQNModel(n_actions=5)
x = torch.zeros(3, 9)
print("input is:\n", x)
print("output is:\n", net(x))
-- 输出 --
input is:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]])
output is:
tensor([[1., 0., 0., 0., 0.],
[0., 1., 0.