先来个代码片段:
#计算q值:策略网络
#gather 函数用于根据给定的索引从一个张量中收集数据,1 表示沿着第二个维度进行索引
#action_batch.unsqueeze(1) 是为了增加一个额外的维度,使得 action_batch 可以作为 gather 函数的索引参数
#squeeze 函数用于移除张量中大小为1的维度,1 表示要移除的维度是第二个维度。
#在 gather 操作之后,由于只选择了一个动作,所以在动作维度上张量的大小变成了1。squeeze(1) 是用来移除这个大小为1的维度,使得结果张量的形状更加紧凑
q_values = self.policy_net(state_batch).gather(1, action_batch.unsqueeze(1)).squeeze(1)
#计算下一个状态的q值:目标网络
#.max(1) 是在得到的Q值上执行的操作,max函数返回两个值:最大值和最大值的索引。这里的1表示沿着第二个维度进行操作。
#[0] 表示只关心最大值,不关心其索引
next_q_values = self.target_net(next_state_batch).max(1)[0]
expected_q_values = reward_batch + self.gamma * next_q_values * (1 - done_batch)
在强化学习中,特别是在使用Q-learning或深度Q网络(DQN)算法时,经常会遇到两个网络:一个是主网络(通常称为policy_net
、online_net
或q_network
),另一个是目标网络(通常称为target_net
)。这两个网络的结构是相同的,但它们的权重是不同步更新的。
使用两个网络的原因是为了稳定学习过程。在DQN算法中,目标网络的权重是定期从主网络复制的,但更新的频率远低于主网络。这种做法有助于减少目标Q值(即预期的回报)与当前Q值(即实际的回报)之间的相关性,从而减少了学习过程中的波动性。
现在,来解释为什么next_q_values
使用target_net
而q_values
使用policy_net
:
- q_values:
q_values
代表在当前状态下采取实际执行的动作所对应的Q值。- 这些Q值是通过主网络(
policy_net
)计算得到的,因为主网络是实时更新的,它反映了最新的策略或价值估计。 - 在给定的状态
state_batch
下,通过policy_net
计算所有可能动作的Q值,然后使用gather
函数根据实际采取的动作action_batch
来选择对应的Q值。
- next_q_values:
next_q_values
代表在下一个状态(next_state_batch
)下可能获得的最高Q值。- 这些Q值是通过目标网络(
target_net
)计算得到的。由于目标网络的权重更新频率较低,它提供了一个更稳定的目标来更新主网络。 - 使用
.max(1)[0]
是为了从目标网络输出的Q值矩阵中找到每个状态对应的最大Q值,这代表了在该状态下可能获得的最高预期回报。
总结来说,使用两个网络(主网络和目标网络)是DQN算法中的一个关键技巧,它有助于减少学习过程中的不稳定性。policy_net
用于选择当前的动作并评估其Q值,而target_net
则用于估计未来状态的Q值,从而指导主网络的更新。