目标网络的评估模式:target_dqn.eval()

本文解释了深度强化学习中target_dqn.eval()函数在DQN算法中的作用,特别是如何在训练和评估模式下处理Dropout和BatchNormalization层。强调了在模型评估和性能测试时正确设置评估模式的重要性,以保证结果的一致性和准确性。
摘要由CSDN通过智能技术生成

在深度强化学习(DRL)中,特别是在深度Q网络(DQN)算法的实现中,target_dqn.eval() 是一个重要的函数调用,它涉及到模型的两种模式:训练模式(train mode)和评估模式(evaluation mode,或者称为推理模式 inference mode)。

target_dqn.eval() 的作用是将目标DQN网络设置为评估模式。在这种模式下,网络中的某些层(如Dropout层和BatchNormalization层)会改变它们的行为:

  1. Dropout层:在训练模式下,Dropout层会随机地将输入张量中的一部分元素设置为0,以防止过拟合。但在评估模式下,Dropout层不会丢弃任何元素,而是会对所有元素进行缩放,以确保输出的总和与训练模式下相同(在统计意义上)。

  2. BatchNormalization层:BatchNormalization层在训练时会使用mini-batch的统计数据进行归一化,并使用可学习的缩放(gamma)和偏移(beta)参数进行调整。在评估模式下,BatchNormalization层会使用在训练阶段计算得到的运行均值和运行方差来进行归一化,而不是使用mini-batch的统计数据。

将模型设置为评估模式是很重要的,特别是在进行模型推断或评估模型性能时。如果在进行模型评估或测试时没有将模型设置为评估模式,那么由于Dropout层或BatchNormalization层的行为差异,可能会得到不一致或不可预测的结果。

在DQN算法中,目标DQN网络(target_dqn)通常用于计算目标Q值,以便稳定学习过程。因此,在计算目标Q值之前,将目标DQN设置为评估模式是很重要的,以确保得到一致且准确的目标值。

lr = 2e-3 num_episodes = 500 hidden_dim = 128 gamma = 0.98 epsilon = 0.01 target_update = 10 buffer_size = 10000 minimal_size = 500 batch_size = 64 device = torch.device("cuda") if torch.cuda.is_available() else torch.device( "cpu") env_name = 'CartPole-v1' env = gym.make(env_name) random.seed(0) np.random.seed(0) #env.seed(0) torch.manual_seed(0) replay_buffer = ReplayBuffer(buffer_size) state_dim = env.observation_space.shape[0] action_dim = env.action_space.n agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon, target_update, device) return_list = [] episode_return = 0 state = env.reset()[0] done = False while not done: action = agent.take_action(state) next_state, reward, done, _, _ = env.step(action) replay_buffer.add(state, action, reward, next_state, done) state = next_state episode_return += reward # 当buffer数据的数量超过一定值后,才进行Q网络训练 if replay_buffer.size() > minimal_size: b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size) transition_dict = { 'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d } agent.update(transition_dict) if agent.count >=200: #运行200步后强行停止 agent.count = 0 break return_list.append(episode_return) episodes_list = list(range(len(return_list))) plt.plot(episodes_list, return_list) plt.xlabel('Episodes') plt.ylabel('Returns') plt.title('DQN on {}'.format(env_name)) plt.show()对上述代码的每一段进行注释,并将其在段落中的作用注释出来
06-12
``` lr = 2e-3 # 学习率 num_episodes = 500 # 训练的总Episode数 hidden_dim = 128 # 隐藏层维度 gamma = 0.98 # 折扣因子 epsilon = 0.01 # ε贪心策略中的ε值 target_update = 10 # 目标网络更新频率 buffer_size = 10000 # 经验回放缓冲区的最大容量 minimal_size = 500 # 经验回放缓冲区的最小容量,达到此容量后才开始训练 batch_size = 64 # 每次训练时的样本数量 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") # 选择CPU或GPU作为运行设备 env_name = 'CartPole-v1' # 使用的环境名称 env = gym.make(env_name) # 创建CartPole-v1环境 random.seed(0) # 随机数生成器的种子 np.random.seed(0) # 随机数生成器的种子 torch.manual_seed(0) # 随机数生成器的种子 replay_buffer = ReplayBuffer(buffer_size) # 创建经验回放缓冲区 state_dim = env.observation_space.shape[0] # 状态空间维度 action_dim = env.action_space.n # 动作空间维度(离散动作) agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon, target_update, device) # 创建DQN智能体 return_list = [] # 用于存储每个Episode的回报 episode_return = 0 # 每个Episode的初始回报为0 state = env.reset()[0] # 环境的初始状态 done = False # 初始状态下没有结束 ``` 以上代码是对程序中所需的参数进行设置和初始化,包括学习率、训练的总Episode数、隐藏层维度、折扣因子、ε贪心策略中的ε值、目标网络更新频率、经验回放缓冲区的最大容量、经验回放缓冲区的最小容量、每次训练时的样本数量、运行设备、使用的环境名称等等。同时,创建了经验回放缓冲区、DQN智能体和用于存储每个Episode的回报的列表,以及初始化了环境状态和结束标志。 ``` while not done: action = agent.take_action(state) # 智能体根据当前状态选择动作 next_state, reward, done, _, _ = env.step(action) # 环境执行动作,观测下一个状态、奖励和结束标志 replay_buffer.add(state, action, reward, next_state, done) # 将当前状态、动作、奖励、下一个状态和结束标志添加到经验回放缓冲区中 state = next_state # 更新状态 episode_return += reward # 累加当前Episode的回报 ``` 以上代码是智能体与环境的交互过程,智能体根据当前状态选择动作,环境执行动作并返回下一个状态、奖励和结束标志,将当前状态、动作、奖励、下一个状态和结束标志添加到经验回放缓冲区中,更新状态,并累加当前Episode的回报。 ``` if replay_buffer.size() > minimal_size: # 当经验回放缓冲区的数据量达到最小容量时,开始训练 b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size) # 从经验回放缓冲区中采样样本 transition_dict = { 'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d } agent.update(transition_dict) # 智能体根据样本更新Q网络 if agent.count >=200: # 运行200步后强行停止 agent.count = 0 break ``` 以上代码是经验回放和Q网络更新过程,当经验回放缓冲区的数据量达到最小容量时,从经验回放缓冲区中采样样本,智能体根据样本更新Q网络。同时,当运行步数超过200步时,强制停止训练。 ``` return_list.append(episode_return) # 将当前Episode的回报添加到回报列表中 ``` 以上代码是将当前Episode的回报添加到回报列表中。 ``` episodes_list = list(range(len(return_list))) # 横坐标为Episode序号 plt.plot(episodes_list, return_list) # 绘制Episode回报随序号的变化曲线 plt.xlabel('Episodes') plt.ylabel('Returns') plt.title('DQN on {}'.format(env_name)) plt.show() ``` 以上代码是绘制Episode回报随序号的变化曲线。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值