gym-0.26.1
CartPole-v1
详细信息
gym 的最新版本是 0.26.2,此后将不再维护,后续开发转移到了 Gymnasium(当前版本为 0.27.0)。
DQN with target net
参考论文:Human-level control through deep reinforcement learning(Mnih et al., Nature, 2015)
目标网络缓解了由自举(bootstrapping)带来的误差传递,但由于 TD 目标 $y = r + \gamma \max_{a'} Q_{\text{target}}(s', a')$ 中含有 $\max$ 操作,Q 值的高估问题仍然比较严重,之后的 Double DQN 会缓解这个问题。
对《动手学强化学习》中的代码做了一些修改。
代码如下
import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import random
import collections
from tqdm import tqdm
import matplotlib.pyplot as plt
from d2l import torch as d2l
import rl_utils
class ReplayBuffer:
    """Experience replay pool holding at most ``capacity`` transitions.

    Storage is a bounded deque, so once it is full every new transition
    evicts the oldest one (first in, first out).
    """

    def __init__(self, capacity):
        # deque(iterable, maxlen): positional form of maxlen=capacity.
        self.buffer = collections.deque([], capacity)

    def add(self, state, action, reward, next_state, done):
        """Store one (state, action, reward, next_state, done) transition."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            ``(states, actions, rewards, next_states, dones)``. ``states`` and
            ``next_states`` are NumPy arrays stacked from the sampled items;
            the other three are plain tuples.

        Raises:
            ValueError: from ``random.sample`` if ``batch_size`` is larger than
                the number of stored transitions, or from tuple unpacking if
                ``batch_size`` is 0.
        """
        minibatch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)
        return (
            np.array(states),
            actions,
            rewards,
            np.array(next_states),
            dones,
        )

    def size(self):
        """Return how many transitions are currently stored."""
        return len(self.buffer)
class Q(nn.Module):
    """Q-network with a single ReLU hidden layer.

    Maps a batch of states to one Q-value per action:
    ``fc2(relu(fc1(X)))``.
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        # The attribute names fc1/fc2 become the state_dict keys, so they are
        # kept as-is to stay compatible with load_state_dict copies.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, X):
        """Return the action values for input states ``X``."""
        return self.fc2(self.fc1(X).relu())
class DQN:
    """DQN agent with a periodically synchronised target network.

    Args:
        state_dim: length of the observation vector fed to the Q-network.
        hidden_dim: width of the Q-network's hidden layer.
        action_dim: number of discrete actions.
        lr: Adam learning rate for the online Q-network.
        gamma: discount factor used in the TD target.
        epsilon: probability of taking a uniformly random action.
        target_update: the target network is overwritten with the online
            weights whenever ``count % target_update == 0`` in ``update``.
        device: torch device holding the networks and the training batches.
    """

    def __init__(self, state_dim, hidden_dim, action_dim, lr, gamma, epsilon, target_update, device):
        self.action_dim = action_dim
        self.q = Q(state_dim, hidden_dim, action_dim).to(device)  # online Q-network (trained)
        self.target_q = Q(state_dim, hidden_dim, action_dim).to(device)  # target network
        # Start the target network from exactly the online network's weights.
        self.target_q.load_state_dict(self.q.state_dict())
        # Only the online network is optimised; the target is updated by copying.
        self.optimizer = torch.optim.Adam(self.q.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon
        self.target_update = target_update  # target-network sync period (in update calls)
        self.count = 0  # number of update() calls performed so far
        self.device = device

    def take_action(self, state):
        """Pick an action for a single ``state`` with an epsilon-greedy policy."""
        if np.random.random() < self.epsilon:
            return np.random.randint(self.action_dim)
        state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
        # Acting needs no gradients; don't build an autograd graph.
        with torch.no_grad():
            return self.q(state).argmax().item()

    def update(self, transition_dict):
        """Run one gradient step on a batch of transitions.

        ``transition_dict`` must provide the keys 'states', 'actions',
        'rewards', 'next_states' and 'dones', each with one entry per
        transition. ``dones`` should be truthy only for terminal transitions,
        because it switches off bootstrapping from the next state.
        """
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        # gather() requires int64 indices.
        actions = torch.tensor(transition_dict['actions'], dtype=torch.long).reshape(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).reshape(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).reshape(-1, 1).to(self.device)

        # Q(s, a) for the actions that were actually taken.
        q_values = self.q(states).gather(1, actions)
        # TD target r + gamma * max_a' Q_target(s', a') * (1 - done). It is a
        # regression constant: computing it under no_grad keeps backward() from
        # tracing into (and accumulating .grad on) the target network.
        with torch.no_grad():
            max_next_q_values = self.target_q(next_states).max(1)[0].reshape(-1, 1)
            q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)

        loss = F.mse_loss(q_values, q_targets)  # already the batch mean
        self.optimizer.zero_grad()  # PyTorch accumulates gradients by default
        loss.backward()
        self.optimizer.step()  # apply the gradient step to the online network

        # Periodically copy the online weights into the target network.
        if self.count % self.target_update == 0:
            self.target_q.load_state_dict(self.q.state_dict())
        self.count += 1
# ---------------- Hyper-parameters ----------------
lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
epsilon = 0.01
target_update = 10
buffer_size = 10000
minimal_size = 500  # training starts once the buffer holds more than this
batch_size = 64
seed = 0
num_iterations = 10  # training progress is reported in this many tqdm bars
device = d2l.try_gpu()
print(device)

# ---------------- Environment, seeding and agent ----------------
env_name = "CartPole-v1"
env = gym.make(env_name)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
replay_buffer = ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon, target_update, device)

# ---------------- Training loop ----------------
episodes_per_iteration = num_episodes // num_iterations
return_list = []
for i in range(num_iterations):
    with tqdm(total=episodes_per_iteration, desc=f'Iteration {i}') as pbar:
        for i_episode in range(episodes_per_iteration):
            episode_return = 0
            # gym 0.26: reset() returns (obs, info). The seeds set above do not
            # reach the environment's own RNG, so seed the very first reset to
            # make runs reproducible; later resets pass seed=None and continue
            # from that seeded state.
            first_episode = i == 0 and i_episode == 0
            state, _ = env.reset(seed=seed if first_episode else None)
            terminated = truncated = False
            while not (terminated or truncated):
                action = agent.take_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                # Store only `terminated` as "done": a time-limit truncation is
                # not a true terminal state, so its TD target must still
                # bootstrap from next_state.
                replay_buffer.add(state, action, reward, next_state, terminated)
                state = next_state
                episode_return += reward
                # Train the Q-network only once the buffer holds enough data.
                if replay_buffer.size() > minimal_size:
                    b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                    transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns,
                                       'rewards': b_r, 'dones': b_d}
                    agent.update(transition_dict)
            return_list.append(episode_return)
            if (i_episode + 1) % 10 == 0:
                pbar.set_postfix({'episode': '%d' % (episodes_per_iteration * i + i_episode + 1),
                                  'return': '%.3f' % np.mean(return_list[-10:])})
            pbar.update(1)
env.close()


# ---------------- Plots ----------------
def _plot_returns(episodes, returns, title):
    """Plot ``returns`` against ``episodes`` with axis labels and show the figure."""
    plt.plot(episodes, returns)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title(title)
    plt.show()


episodes_list = list(range(len(return_list)))
_plot_returns(episodes_list, return_list, f'DQN on {env_name}')
# NOTE(review): rl_utils.moving_average is defined elsewhere; plotting it
# against episodes_list assumes it returns one value per episode — confirm.
mv_return = rl_utils.moving_average(return_list, 9)
_plot_returns(episodes_list, mv_return, f'DQN on {env_name}')
同样在 Jupyter 中运行,结果如下:
从结果来看,网络性能在两个阶段都得到了快速提升,但在后一阶段振荡加剧,推测是过拟合问题加剧所致。