DDPG 开源项目使用教程
1. 项目的目录结构及介绍
DDPG/
├── agents/
│ ├── __init__.py
│ ├── ddpg_agent.py
│ └── ou_noise.py
├── models/
│ ├── __init__.py
│ ├── actor.py
│ └── critic.py
├── replay_buffer.py
├── utils.py
├── config.py
├── main.py
├── README.md
└── requirements.txt
- agents/: 包含代理(agent)相关的文件,如
ddpg_agent.py
定义了DDPG代理的类,ou_noise.py
定义了Ornstein-Uhlenbeck噪声过程。 - models/: 包含神经网络模型的定义,如
actor.py
和critic.py
。 - replay_buffer.py: 定义了经验回放缓冲区。
- utils.py: 包含一些实用工具函数。
- config.py: 项目的配置文件。
- main.py: 项目的启动文件。
- README.md: 项目说明文档。
- requirements.txt: 项目依赖的Python包列表。
2. 项目的启动文件介绍
main.py
是项目的启动文件,负责初始化环境、代理和训练过程。以下是 main.py
的主要内容:
import argparse
from agents.ddpg_agent import DDPGAgent
from utils import make_env
def main():
parser = argparse.ArgumentParser(description='DDPG Agent')
parser.add_argument('--env_name', type=str, default='Pendulum-v0', help='Environment name')
parser.add_argument('--seed', type=int, default=0, help='Random seed')
parser.add_argument('--num_episodes', type=int, default=1000, help='Number of episodes')
parser.add_argument('--render', action='store_true', help='Render the environment')
args = parser.parse_args()
env = make_env(args.env_name, args.seed)
agent = DDPGAgent(env)
for episode in range(args.num_episodes):
state = env.reset()
done = False
while not done:
if args.render:
env.render()
action = agent.act(state)
next_state, reward, done, _ = env.step(action)
agent.remember(state, action, reward, next_state, done)
agent.learn()
state = next_state
if __name__ == "__main__":
main()
- 解析命令行参数: 使用
argparse
模块解析命令行参数,包括环境名称、随机种子、训练集数和是否渲染环境。 - 初始化环境: 使用
make_env
函数初始化环境。 - 初始化代理: 创建
DDPGAgent
实例。 - 训练循环: 在每个 episode 中,重置环境,执行动作,存储经验,并进行学习。
3. 项目的配置文件介绍
config.py
文件包含了项目的配置参数,如学习率、折扣因子、缓冲区大小等。以下是 config.py
的主要内容:
# Hyperparameters
LEARNING_RATE_ACTOR = 1e-4
LEARNING_RATE_CRITIC = 1e-3
GAMMA = 0.99
TAU = 1e-3
BUFFER_SIZE = int(1e6)
BATCH_SIZE = 64
# Environment
ENV_NAME = 'Pendulum-v0'
SEED = 0
# Training
NUM_EPISODES = 1000
RENDER = False
- 超参数: 包括演员和评论家的学习率 (
LEARNING_RATE_ACTOR
,LEARNING_RATE_CRITIC
)、折扣因子 (GAMMA
)、软更新参数 (TAU
)、缓冲区大小 (BUFFER_SIZE
) 和批次大小 (BATCH_SIZE
)。 - 环境参数: 环境名称 (
ENV_NAME
) 和随机种子 (SEED
)。 - 训练参数: 训练集数