DQN 跑 CartPole-v1

gym-0.26.1
CartPole-v1 详细信息

gym 最新的版本是 0.26.2,但是之后要不维护了,转到了Gymnasium版本是0.27.0。

DQN with target net 论文参考 Human-level control through deep reinforcement learning

目标网络缓解了由自举带来的误差传递,但是由于有max操作,所以高估问题还是比较严重,之后的double DNQ会缓解这个问题。

动手深度强化学习里的代码做了一些修改。
代码如下

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import random
import collections
from tqdm import tqdm
import matplotlib.pyplot as plt
from d2l import torch as d2l
import rl_utils

class ReplayBuffer:
    """经验回放池"""
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity) # 队列,先进先出
    
    def add(self, state, action, reward, next_state, done): # 将数据加入buffer
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size): # 从buffer中采样数据,数量为batch_size
        transition = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transition)
        return np.array(state), action, reward, np.array(next_state), done
    
    def size(self): # 目前buffer中数据的数量
        return len(self.buffer)

class Q(nn.Module):
    """只有一层隐藏层的Q网络"""
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)
    
    def forward(self, X):
        X = F.relu(self.fc1(X)) # 隐藏层之后使用ReLU激活函数
        return self.fc2(X)

class DQN:
    """DQN算法"""
    def __init__(self, state_dim, hidden_dim, action_dim, lr, gamma, epsilon, target_update, device):
        self.action_dim = action_dim
        self.q = Q(state_dim, hidden_dim, action_dim).to(device) # Q网络
        self.target_q = Q(state_dim, hidden_dim, action_dim).to(device) # 目标网络
        self.target_q.load_state_dict(self.q.state_dict())  # 加载参数
        self.optimizer = torch.optim.Adam(self.q.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon
        self.target_update = target_update # 目标网络更新频率
        self.count = 0 # 计数器,记录更新次数
        self.device = device
    
    def take_action(self, state): # epsilon-贪婪策略
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
            action = self.q(state).argmax().item()
        return action
    
    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).reshape(-1,1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).reshape(-1,1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).reshape(-1,1).to(self.device)
        
        q_values = self.q(states).gather(1, actions) # Q值
        # 下个状态的最大Q值
        max_next_q_values = self.target_q(next_states).max(1)[0].reshape(-1,1)
        q_targets = rewards + self.gamma * max_next_q_values * (1- dones) # TD误差
        loss = F.mse_loss(q_values, q_targets) # 均方误差
        self.optimizer.zero_grad() # 梯度清零,因为默认会梯度累加
        loss.mean().backward() # 反向传播
        self.optimizer.step() # 更新梯度
        
        if self.count % self.target_update == 0:
            self.target_q.load_state_dict(self.q.state_dict())
        self.count += 1
lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
epsilon = 0.01
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = d2l.try_gpu()
print(device)

env_name = "CartPole-v1"
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
replay_buffer = ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon, target_update, device)
return_list = []

for i in range(10):
    with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:
        for i_episode in range(int(num_episodes/10)):
            episode_return = 0
            state = env.reset()[0]
            done, truncated= False, False
            while not done and not truncated :
                action = agent.take_action(state)
                next_state, reward, done, truncated, info = env.step(action)
                replay_buffer.add(state, action, reward, next_state, done)
                state = next_state
                episode_return += reward
                # 当buffer数据的数量超过一定值后,才进行Q网络训练
                if replay_buffer.size() > minimal_size:
                    b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                    transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d}
                    agent.update(transition_dict)
            return_list.append(episode_return)
            if (i_episode+1) % 10 == 0:
                pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 
                                  'return': '%.3f' % np.mean(return_list[-10:])})
            pbar.update(1)
            
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'DQN on {env_name}')
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'DQN on {env_name}')
plt.show()

同样的在jupyter中运行,结果如下



从结果上来看,网络在两个阶段,性能得到了快速提升,但后一个阶段振荡加剧了,应该是过拟合的问题加剧了。

  • 10
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: DQN (Deep Q-Network) 是一种基于深度神经网络的强化学习算法,它可用于解决强化学习环境中的问题。PyTorch 是一个开源的深度学习框架,提供了一种简单而强大的方式来构建和训练神经网络模型。CartPole-v0 是 OpenAI Gym 提供的一个强化学习环境,目标是控制一个摆杆平衡在垂直位置上。 在使用 PyTorch 实现 DQN 解决 CartPole-v0 问题时,需要首先定义一个深度神经网络模型作为 Q 函数的近似。这个模型通常包含若干隐藏层和一个输出层,用于预测在给定状态下采取各个动作的 Q 值。 然后,需要定义一个经验回放(Experience Replay)的缓冲区,用于存储智能体在环境中的经验,包括当前状态、动作、奖励和下一个状态。 接下来,使用 epsilon-greedy 策略选择动作,epsilon 表示随机探索的概率,即以一定概率选择随机动作,以一定概率选择当前 Q 值最大的动作。 将选择的动作应用于环境中,观察下一状态和奖励,并将这些经验存储到经验回放缓冲区中。 每隔一定步数,从经验回放缓冲区中采样一批数据,然后利用这些样本数据来更新神经网络的参数。DQN 使用经验回放的方式进行训练,这样可以减少样本间的相关性,提高样本的利用效率。 通过反向传播算法计算损失函数,并利用优化器更新神经网络的参数,使得神经网络的输出 Q 值逼近真实的 Q 值。 重复进行上述步骤,直到智能体能够有效地平衡摆杆,或者达到预定的训练次数。 在实际实现 DQN 算法过程中,还需要注意学习速率、discount factor 等超参数的选择,以及选择合适的损失函数和优化器来训练神经网络模型。 总结来说,使用 PyTorch 实现 DQN 来解决 CartPole-v0 问题,需要先定义一个深度神经网络模型作为 Q 函数的近似,然后利用经验回放的方式进行训练,通过反向传播算法来更新神经网络参数,使模型能够逼近真实的 Q 值,最终达到使摆杆平衡的目标。 ### 回答2: DQN(深度Q网络)是一种强化学习算法,用于解决各种控制问题,包括CartPole-v0这个经典的强化学习环境。PyTorch是一种深度学习框架,可以方便地构建神经网络模型。 在使用PyTorch实现DQN解决CartPole-v0问题时,我们首先需要定义网络模型。可以使用PyTorch提供的nn模块创建一个多层感知机网络,包含输入层、若干隐藏层和输出层。这个网络的输入是CartPole-v0环境的状态,输出是动作的Q值。使用ReLU作为激活函数可以增加网络的非线性表示能力。 定义好网络模型后,我们需要定义DQN的训练过程。首先,根据当前环境状态输入网络获取各个动作的Q值,然后选择Q值最大的动作作为当前的行动。执行动作后,环境将返回下一个状态、奖励和是否结束的信息。将这些信息存储在经验回放缓冲区中。 接下来,我们从经验回放缓冲区中随机采样一批数据,包括之前的状态、行动、奖励和下一个状态。然后,使用目标网络(Target Network)计算下一个状态的Q值,并根据贝尔曼方程计算当前状态的目标Q值。通过最小化当前状态的动作Q值和目标Q值的差距,更新网络的参数。 在DQN的训练过程中,还需要设置超参数,包括学习率、批大小、epsilon-greedy策略的参数等。为了提高收敛速度和稳定性,可以使用经验回放和目标网络两个技术。 最后,通过多次迭代训练,不断优化网络参数,直到DQN模型在CartPole-v0环境上能够稳定地获得较高的得分。 总之,使用PyTorch实现DQN算法解决CartPole-v0问题需要定义网络模型、训练过程和超参数,并使用经验回放和目标网络等技术进行优化,以提高性能和稳定性。 ### 回答3: DQN是一种使用深度神经网络进行强化学习的算法,它使用PyTorch框架实现,在CartPole-v0环境中非常有用。 CartPole-v0是一个经典的强化学习问题,任务是控制一个平衡杆,使其在变化的条件下保持平衡。这个环境具有四个状态变量:杆的角度、杆的速度、小车的位置和小车的速度。在每个时间步骤,智能体可以向左或向右施加力来控制小车的动作。目标是使杆保持在竖直位置,并且尽可能长时间地保持平衡。 DQN算法使用了深度神经网络来估计每种动作的Q值函数。在PyTorch中,我们可以使用nn.Module类创建深度神经网络模型,可以包含一些全连接层和非线性激活函数。DQN算法还使用了经验回放机制和目标网络来提高训练效果。 在CartPole-v0中,我们可以使用PyTorch中的torchvision.transforms对环境状态进行处理。然后,我们可以使用DQN模型以一定的epsilon-greedy策略来选择动作,并与环境进行交互。每个时间步之后,我们从经验回放缓冲区中随机样本一批数据,然后计算损失并更新网络参数。我们还会定期更新目标网络的权重,以确保稳定的学习过程。 通过使用DQN算法和PyTorch框架,我们可以在CartPole-v0环境中实现高效的强化学习训练。我们可以通过调整网络结构、超参数和训练步骤来提高性能,并使智能体在该环境中获得长时间平衡杆的能力。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值