What You Will Learn in This Lesson
- Understand the core ideas behind Deep Q-Networks (DQN)
- Master the two key techniques: experience replay and target networks
- Implement a complete DQN algorithm in PyTorch
- Train an AI agent that can play Atari games
Before You Start
Requirements
- Python 3.8+
- PyTorch 2.0+
- Gymnasium[atari] (includes the Atari game environments)
- OpenCV (for image preprocessing)
- A CUDA-capable GPU is strongly recommended (training on CPU works but is impractically slow)
Prerequisites
- Reinforcement learning basics (Lesson 33)
- Convolutional neural networks (Lesson 24)
- PyTorch tensor operations (Lesson 22)
Core Concepts
How DQN differs from classic Q-learning

| Aspect | Q-learning | DQN |
|---|---|---|
| State representation | Discrete lookup table | Raw pixel images |
| Function approximation | Table lookup | Deep neural network |
| Memory efficiency | Low (state-space explosion) | High |
| Typical use | Small, discrete problems | Complex visual tasks |
The two key techniques in DQN

1. Experience Replay
   - Store transitions (s, a, r, s') in a replay buffer
   - Sample random mini-batches during training to break correlations between consecutive samples
   - Intuition: learning by "rewatching game footage"
2. Target Network
   - Use two networks: an online (policy) network and a target network
   - Periodically copy the online network's weights into the target network
   - Stabilizes the Q-value targets (avoids the "moving target" problem)
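Putting the two techniques together: each sampled transition is used to regress the online network $Q_\theta$ toward a target computed with the frozen target network $Q_{\theta^-}$. For reference, this is the standard DQN target that the `optimize_model` code below implements:

$$
y = r + \gamma\,(1 - \text{done})\,\max_{a'} Q_{\theta^-}(s', a'), \qquad \mathcal{L}(\theta) = \text{SmoothL1}\big(Q_\theta(s, a),\, y\big)
$$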
Hands-On Code
1. Environment preprocessing
Atari frames need dedicated image preprocessing before they are fed to the network:

import gymnasium as gym
import numpy as np
import cv2
from collections import deque
import torch

class AtariWrapper:
    def __init__(self, env_name, frame_stack=4):
        self.env = gym.make(env_name, render_mode='rgb_array')
        self.frame_stack = frame_stack
        self.frames = deque(maxlen=frame_stack)

    def reset(self):
        obs, _ = self.env.reset()
        processed = self._preprocess(obs)
        for _ in range(self.frame_stack):
            self.frames.append(processed)
        return self._get_state()

    def _preprocess(self, frame):
        # Convert to grayscale
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # Crop and resize
        frame = frame[34:194, :]  # remove the scoreboard area
        frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA)
        return frame / 255.0  # normalize to [0, 1]

    def _get_state(self):
        return np.stack(self.frames, axis=0)  # shape: (4, 84, 84)

    def step(self, action):
        obs, reward, terminated, truncated, _ = self.env.step(action)
        processed = self._preprocess(obs)
        self.frames.append(processed)
        next_state = self._get_state()
        done = terminated or truncated
        return next_state, reward, done

    def render(self):
        return self.env.render()

    def close(self):
        self.env.close()

# Create the environment
env = AtariWrapper("PongNoFrameskip-v4")
print("Observation shape:", env.reset().shape)

# ⚠️ Common mistake 1: wrong frame-stacking order
# The correct order is [t-3, t-2, t-1, t], which preserves temporal continuity
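Note that this wrapper applies no action repeat, even though the environment ID is `PongNoFrameskip-v4`; the usual DeepMind setup repeats each action for 4 raw frames and max-pools the last two frames to remove sprite flicker. Below is a minimal sketch of such a method you could add to `AtariWrapper` (the `skip=4` value and the max-pooling follow common practice and are not part of this lesson's wrapper):

```python
def step_with_skip(self, action, skip=4):
    """Hypothetical helper: repeat `action` for `skip` raw frames and sum the rewards."""
    total_reward, done = 0.0, False
    raw_frames = []
    for _ in range(skip):
        obs, reward, terminated, truncated, _ = self.env.step(action)
        total_reward += reward
        raw_frames.append(obs)
        if terminated or truncated:
            done = True
            break
    # Max over the last two raw frames to remove Atari sprite flicker
    obs = np.max(np.stack(raw_frames[-2:]), axis=0)
    self.frames.append(self._preprocess(obs))
    return self._get_state(), total_reward, done
```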
2. DQN network architecture

import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    def __init__(self, n_actions):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(7*7*64, 512)
        self.fc2 = nn.Linear(512, n_actions)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # flatten
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Quick network test
sample_input = torch.randn(1, 4, 84, 84)  # batch size 1, 4 stacked 84x84 frames
model = DQN(n_actions=env.env.action_space.n)
print("Q-value output shape:", model(sample_input).shape)

# ⚠️ Common mistake 2: dimension mismatch
# Make sure that:
# 1. the number of input channels equals the number of stacked frames
# 2. the output dimension equals the size of the action space
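If you change the input resolution or the convolution hyperparameters, the hard-coded `7*7*64` in `fc1` will no longer match. A quick way to verify the flattened size is to push a dummy tensor through the conv stack (a small sanity check, not part of the lesson code itself):

```python
# Verify the flattened feature size that feeds fc1
with torch.no_grad():
    dummy = torch.zeros(1, 4, 84, 84)
    x = F.relu(model.conv1(dummy))
    x = F.relu(model.conv2(x))
    x = F.relu(model.conv3(x))
    print("conv output shape:", x.shape)              # expected: (1, 64, 7, 7)
    print("flattened size:", x.view(1, -1).size(1))   # expected: 3136 = 7*7*64
```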
3. Implementing experience replay

import random
from collections import namedtuple

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'done'))

class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

# Initialize the replay buffer
memory = ReplayMemory(100000)

# Example: store one transition
state = env.reset()
action = env.env.action_space.sample()
next_state, reward, done = env.step(action)
memory.push(state, action, next_state, reward, done)
print("Current replay buffer size:", len(memory))
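One practical caveat: the wrapper above stores normalized float frames, so 100,000 transitions (state plus next_state, each 4×84×84 in float64) can consume tens of GB of RAM. A common workaround is to keep the frames as `uint8` in the buffer and only normalize the sampled mini-batch; here is a minimal sketch (the class name and conversion details are illustrative assumptions, not part of the lesson's project code):

```python
class CompactReplayMemory(ReplayMemory):
    """Hypothetical variant that stores frames as uint8 to save roughly 8x memory."""

    def push(self, state, action, next_state, reward, done):
        # Convert normalized float frames back to uint8 before storing
        state = (np.asarray(state) * 255).astype(np.uint8)
        next_state = (np.asarray(next_state) * 255).astype(np.uint8)
        super().push(state, action, next_state, reward, done)

    def sample(self, batch_size):
        batch = super().sample(batch_size)
        # Normalize back to [0, 1] floats only for the sampled mini-batch
        return [t._replace(state=t.state.astype(np.float32) / 255.0,
                           next_state=t.next_state.astype(np.float32) / 255.0)
                for t in batch]
```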
4. The full training pipeline

import math
import torch.optim as optim

class DQNAgent:
    def __init__(self, n_actions, device):
        self.n_actions = n_actions
        self.device = device
        # Two-network setup: online (policy) network + target network
        self.policy_net = DQN(n_actions).to(device)
        self.target_net = DQN(n_actions).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        # Optimizer
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=0.0001)
        # Hyperparameters
        self.batch_size = 32
        self.gamma = 0.99
        self.eps_start = 1.0
        self.eps_end = 0.01
        self.eps_decay = 10000
        self.steps_done = 0
        self.target_update = 1000

    def select_action(self, state):
        sample = random.random()
        eps_threshold = self.eps_end + (self.eps_start - self.eps_end) * \
            math.exp(-1. * self.steps_done / self.eps_decay)
        self.steps_done += 1
        if sample > eps_threshold:
            with torch.no_grad():
                # Greedy action from the online network
                state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)
                return self.policy_net(state).max(1)[1].view(1, 1).item()
        else:
            # Random exploration
            return random.randrange(self.n_actions)

    def optimize_model(self):
        if len(memory) < self.batch_size:
            return
        # Sample a mini-batch from the replay buffer
        transitions = memory.sample(self.batch_size)
        batch = Transition(*zip(*transitions))
        # Convert the batch to tensors
        state_batch = torch.tensor(np.array(batch.state), dtype=torch.float32).to(self.device)
        action_batch = torch.tensor(batch.action).unsqueeze(1).to(self.device)
        reward_batch = torch.tensor(batch.reward, dtype=torch.float32).unsqueeze(1).to(self.device)
        next_state_batch = torch.tensor(np.array(batch.next_state), dtype=torch.float32).to(self.device)
        done_batch = torch.tensor(batch.done, dtype=torch.float32).unsqueeze(1).to(self.device)
        # Current Q-values Q(s, a) from the online network
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)
        # Target Q-values from the target network
        with torch.no_grad():
            next_state_values = self.target_net(next_state_batch).max(1)[0].unsqueeze(1)
        expected_state_action_values = reward_batch + (self.gamma * next_state_values * (1 - done_batch))
        # Huber loss between current and target Q-values
        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)
        # Optimization step
        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping
        for param in self.policy_net.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

    def update_target_net(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())

# Initialization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = DQNAgent(env.env.action_space.n, device)
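Before launching a full training run, it can be worth a quick smoke test that action selection and a single optimization step run without shape errors (this assumes the `env`, `memory`, and `agent` objects created above):

```python
# Optional smoke test: one transition + one (possibly skipped) optimization step
state = env.reset()
action = agent.select_action(state)
next_state, reward, done = env.step(action)
memory.push(state, action, next_state, reward, done)
agent.optimize_model()  # returns silently until the buffer holds at least batch_size samples
print("selected action:", action, "| buffer size:", len(memory))
```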
5. Main training loop

import matplotlib.pyplot as plt
from IPython import display

def plot_rewards(episode_rewards, mean_rewards):
    display.clear_output(wait=True)
    plt.figure(figsize=(10, 5))
    plt.title('Training Progress')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.plot(episode_rewards, label='Episode Reward')
    plt.plot(mean_rewards, label='Mean Reward (100 episodes)')
    plt.legend()
    plt.grid()
    plt.show()

def train(num_episodes=1000):
    episode_rewards = []
    mean_rewards = []
    best_mean = -float('inf')
    for i_episode in range(num_episodes):
        state = env.reset()
        total_reward = 0
        while True:
            # Select and execute an action
            action = agent.select_action(state)
            next_state, reward, done = env.step(action)
            total_reward += reward
            # Store the transition
            memory.push(state, action, next_state, reward, done)
            state = next_state
            # Optimize the model
            agent.optimize_model()
            # Periodically sync the target network
            if agent.steps_done % agent.target_update == 0:
                agent.update_target_net()
            if done:
                break
        # Record training progress
        episode_rewards.append(total_reward)
        mean_100 = np.mean(episode_rewards[-100:])
        mean_rewards.append(mean_100)
        # Keep the best model so far
        if mean_100 > best_mean:
            best_mean = mean_100
            torch.save(agent.policy_net.state_dict(), 'dqn_best.pth')
        # Show progress
        if i_episode % 10 == 0:
            plot_rewards(episode_rewards, mean_rewards)
            print(f"Episode {i_episode}, Reward: {total_reward}, Mean Reward: {mean_100:.1f}")
    env.close()
    return episode_rewards

# Start training
rewards = train(num_episodes=500)
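After training, the saved weights can be evaluated with a purely greedy policy (ε = 0). A minimal sketch of what `eval.py` might look like (the episode count and the fresh environment instance are assumptions; `train()` closes the original `env` when it finishes):

```python
def evaluate(checkpoint="dqn_best.pth", episodes=5):
    """Greedy (epsilon = 0) rollouts with the saved policy network."""
    eval_env = AtariWrapper("PongNoFrameskip-v4")  # fresh env, since train() closed the old one
    net = DQN(eval_env.env.action_space.n).to(device)
    net.load_state_dict(torch.load(checkpoint, map_location=device))
    net.eval()
    for ep in range(episodes):
        state, total, done = eval_env.reset(), 0.0, False
        while not done:
            with torch.no_grad():
                s = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
                action = net(s).argmax(dim=1).item()
            state, reward, done = eval_env.step(action)
            total += reward
        print(f"Eval episode {ep}: reward = {total}")
    eval_env.close()

evaluate()
```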
Full Project
Project layout:
lesson_34_dqn_atari/
├── envs/
│   ├── atari_wrapper.py     # environment preprocessing
│   └── utils.py             # helper functions
├── models/
│   ├── dqn.py               # network architecture
│   └── memory.py            # experience replay
├── agents/
│   └── dqn_agent.py         # DQN agent
├── configs/
│   └── hyperparams.yaml     # hyperparameter config
├── train.py                 # training script
├── eval.py                  # evaluation script
├── requirements.txt         # dependency list
└── README.md                # project README
requirements.txt
gymnasium[atari]==0.28.1
torch==2.0.1
opencv-python==4.7.0.72
numpy==1.24.3
matplotlib==3.7.1
pygame==2.3.0
tensorboard==2.12.0
configs/hyperparams.yaml
# Training parameters
batch_size: 32
gamma: 0.99
eps_start: 1.0
eps_end: 0.01
eps_decay: 10000
target_update: 1000
memory_capacity: 100000
learning_rate: 0.0001
# Environment parameters
env_name: "PongNoFrameskip-v4"
frame_stack: 4
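In the full project these values would typically be read from the YAML file rather than hard-coded in the agent. A minimal loading sketch (this assumes PyYAML, which is not listed in requirements.txt above; add it if you take this approach):

```python
import yaml  # PyYAML; assumed dependency, not in requirements.txt above

with open("configs/hyperparams.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["env_name"], cfg["learning_rate"])
# e.g. pass cfg values into DQNAgent instead of the hard-coded defaults
```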
Expected Results
Sample training output
Episode 0, Reward: -21.0, Mean Reward: -21.0
Episode 10, Reward: -20.0, Mean Reward: -20.5
...
Episode 300, Reward: 12.0, Mean Reward: 8.2
Episode 500, Reward: 18.0, Mean Reward: 14.7
Common Issues
Q1: Scores are extremely low early in training (e.g. Pong stuck at -21)
How to handle it:
- This is normal; expect a "warm-up" phase of roughly 100-200 episodes
- Check that optimization only starts once the replay buffer has accumulated enough samples (see the sketch below)
- Make sure the ε-greedy policy explores enough at the start (eps_start = 1.0)
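On the second point, a common pattern is to gate the optimization call on a minimum buffer size; the `LEARNING_STARTS` name and the value of 10,000 below are assumptions, not values from this lesson's code:

```python
LEARNING_STARTS = 10_000  # assumed warm-up threshold; tune per game

# inside the training loop, instead of calling optimize_model() unconditionally:
if len(memory) >= LEARNING_STARTS:
    agent.optimize_model()
```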
Q2: Training is unstable and scores swing wildly
Possible causes:
- The target-network update interval is off (tune target_update)
- The learning rate is too high (try lowering it to 0.00001)
- The batch size is too small (increase it to 64 or 128)
Q3: How do I apply this to other Atari games?
Suggested adjustments:
- Change the image preprocessing (the useful screen region differs from game to game)
- Adjust the reward scaling (unlike Pong's ±1 rewards, games such as Breakout hand out rewards of varying size; see the snippet after this list)
- Harder games may need a larger network or the improvements covered in the exercises below (the conv stack above is already the Nature DQN architecture)
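For the reward-scaling point, the original DQN papers clip every per-step reward to {-1, 0, +1}; Pong already satisfies this, but it matters for games with larger per-event scores. A one-line sketch you could apply to `reward` inside the training loop:

```python
# Clip rewards to {-1, 0, +1}, as in the original DQN papers
clipped_reward = float(np.sign(reward))
```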
Exercises

1. Double DQN
   Implement Double DQN and compare its performance against vanilla DQN (a starter hint follows this list).
2. Prioritized experience replay
   Modify the replay buffer so that more important transitions are sampled with higher probability.
3. Dueling DQN architecture
   Rework the network into the Dueling architecture (separate state-value and advantage streams).
4. Multi-step TD learning
   Implement n-step TD returns, trading off the strengths and weaknesses of MC and TD methods.
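As a hint for exercise 1: Double DQN only changes how the TD target is formed inside `optimize_model` (the online network selects the greedy next action, and the target network evaluates it). One way to write the modified lines, using the variable names from the code above:

```python
# Double DQN target: decouple action selection from action evaluation
with torch.no_grad():
    # online network picks the greedy action in the next state...
    next_actions = self.policy_net(next_state_batch).argmax(dim=1, keepdim=True)
    # ...and the target network evaluates that action
    next_state_values = self.target_net(next_state_batch).gather(1, next_actions)
expected_state_action_values = reward_batch + self.gamma * next_state_values * (1 - done_batch)
```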
Further Reading
- The original DQN paper ("Human-level control through deep reinforcement learning")
- Rainbow: a state-of-the-art agent that combines several DQN improvements
- The Stable Baselines3 implementation, as a reference

Next up: Lesson 35 explores policy gradient methods and tackles continuous control tasks!