第0-(3)章-DRL的细碎笔记-Q学习中的Q网络如何训练
作者:想要飞的猪
工作地点:北京科技大学
DQN与PPO以及DDPG等算法不同,DQN仅有一个作为Critic的Q网络。不同于PPO以及DDPG的Critic网络仅输出一个Q值,DQN处理离散动作时Q网络最终的输出是对应不同离散动作的多个Q值,然后选取其中Q值最大的动作。DQN的网络结构如下图(图片来自于李宏毅老师的深度强化学习课程):
为了更好的理解多个输出的Q网络如何训练,这次的博客结合DQN中Q网络训练时的损失函数,分析相应的代码实现。
1. DQN的损失函数
在DQN中,Q网络的损失函数为:
L
(
θ
i
)
=
E
(
s
,
a
,
r
,
s
′
)
∼
U
(
D
)
[
(
y
−
Q
(
s
,
a
;
θ
i
)
)
2
]
L\left(\theta_i\right)=\mathbb{E}_{\left(s, a, r, s^{\prime}\right) \sim U(D)}\left[\left(y-Q\left(s, a ; \theta_i\right)\right)^2\right]
L(θi)=E(s,a,r,s′)∼U(D)[(y−Q(s,a;θi))2]
其中,
y
y
y是目标值,
y
=
r
+
γ
max
a
′
Q
(
s
′
,
a
′
;
θ
i
−
)
y=r+\gamma \max _{a^{\prime}} Q\left(s^{\prime}, a^{\prime} ; \theta_i^{-}\right)
y=r+γmaxa′Q(s′,a′;θi−)。
根据
L
(
θ
i
)
L\left(\theta_i\right)
L(θi)的表达式,里面两次用到了Q网络,一次用到了需要迭代学习的Q网络,另外一次用到了target Q网络。在Q网络的输出中有多个Q值,需要根据策略选取某个Q值。首先,先展示一下DQN中Q网络的架构,示例代码如下:
class QNet(nn.Module):
def __init__(self, state_dim, action_dim):
super(DQN, self).__init__()
self.fc1 = nn.Linear(state_dim, 64)
self.fc2 = nn.Linear(64, 64)
self.fc3 = nn.Linear(64, action_dim)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
从上述代码可以看出,最后Q网络最后返回值的维度是离散动作的维度。在代码中从多个离散动作中选取某个动作以及构建损失函数的代码如下:
class DQN():
...
def select_action(self, state, eps):
if random.random() < eps:
return random.randint(0, self.action_dim - 1)
else:
state = torch.FloatTensor(state).to(self.device)
with torch.no_grad():
action = self.policy_net(state).argmax().item()
return action
def store_transition(self, state, action, reward, next_state, done):
self.memory.append((state, action, reward, next_state, done))
def train(self):
if len(self.memory) < self.batch_size:
return
transitions = random.sample(self.memory, self.batch_size)
batch = list(zip(*transitions))
state_batch = torch.FloatTensor(batch[0]).to(self.device)
action_batch = torch.LongTensor(batch[1]).to(self.device)
reward_batch = torch.FloatTensor(batch[2]).to(self.device)
next_state_batch = torch.FloatTensor(batch[3]).to(self.device)
done_batch = torch.FloatTensor(batch[4]).to(self.device)
q_values = self.policy_net(state_batch).gather(1, action_batch.unsqueeze(1)).squeeze(1)
next_q_values = self.target_net(next_state_batch).max(1)[0]
expected_q_values = reward_batch + self.gamma * next_q_values * (1 - done_batch)
loss = self.loss_fn(q_values, expected_q_values.detach())
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.steps += 1
self.writer.add_scalar("Loss", loss.item(), self.steps)
...
在上述代码中需要注意的几个点:
(1)实现
L
(
θ
i
)
L\left(\theta_i\right)
L(θi)中
Q
(
s
,
a
;
θ
i
)
Q\left(s, a ; \theta_i\right)
Q(s,a;θi)的代码为:
q_values = self.policy_net(state_batch).gather(1, action_batch.unsqueeze(1)).squeeze(1)
其中,.gather()函数的作用是依据action_batch中对应的动作索引选择policy_net多个Q值中对应的Q值。unsqueeze与squeeze就是在变换数据的维度信息,使得数据的输入输出符合运算的需求。
(2)实现
L
(
θ
i
)
L\left(\theta_i\right)
L(θi)中
y
y
y的代码为:
next_q_values = self.target_net(next_state_batch).max(1)[0]
expected_q_values = reward_batch + self.gamma * next_q_values * (1 - done_batch)
需要注意的是,Q网络训练时,只有与当前动作(action_batch)对应的输出层参数会参与计算损失函数和梯度更新。其他输出层参数的梯度将为零,因为它们不参与计算损失函数和目标值(这个地方的意思是,如果当前Q网络中动作1被选取,那么神经网络backward只是从当前被选取的动作1开始backward,其他动作对应的输出层参数不会backward)。这样做的目的是确保只有与当前动作有关的参数会被更新。
(3)对batch进行分类整理的代码为:
transitions = random.sample(self.memory, self.batch_size)
batch = list(zip(*transitions))
state_batch = torch.FloatTensor(batch[0]).to(self.device)
action_batch = torch.LongTensor(batch[1]).to(self.device)
reward_batch = torch.FloatTensor(batch[2]).to(self.device)
next_state_batch = torch.FloatTensor(batch[3]).to(self.device)
done_batch = torch.FloatTensor(batch[4]).to(self.device)
这段代码利用zip以及*对batch数据进行分类处理,具体的内容可以看下面的补充知识点(学习代码最重要的是一边跑代码一边学习,然后在不确定的地方加个print看一下,这块内容的我也是一边查资料一边通过python跑代码理解的,如果对数据的shape不确定就print一下数据的shape,如果是type就print一下type)。
(4)上述代码中select_action函数的作用是接收到某个state然后给出对应的action的索引,索引的获得由以下代码实现:
action = self.policy_net(state).argmax().item()
其中,item()获取最大 Q 值的动作索引,同时item()将张量转换为 Python 基本类型时会自动剥离梯度,这样就不会对梯度更新产生影响(这一点在其他算法中的choose_action函数中也需要注意,需要在最后返回action时将梯度剥离)。
2. 补充知识点
(1)*
在Python中,星号(*)被用来进行解包(unpacking)操作,可以将一个可迭代对象(如列表、元组等)中的元素分别解开,作为独立的位置参数传递给函数或构造新的序列。例子如下:
transitions = [ (1, 'a'), (2, 'b'), (3, 'c') ]
print(*transitions)
运行结果如下:
(1, 2, 3), ('a', 'b', 'c')
(2)zip
在使用pytorch编写DRL的代码时,通常会使用到zip函数。zip函数的作用是就是把不同可迭代对象中对应的元素对应取出,组成一个元组。这样解释有些晦涩,直接看下面的例子:
lett = ['a', 'b']
num = [1, 2]
for x, y in zip(lett, num):
print('x:', x, 'y:', y)
print('x_type:', type(x), 'y_type:', type(y))
for lett_num in zip(lett, num):
print('lett_num:', lett_num, 'lett_num_type:', type(lett_num))
运行结果如下:
x: a y: 1
x_type: <class 'str'> y_type: <class 'int'>
x: b y: 2
x_type: <class 'str'> y_type: <class 'int'>
lett_num: ('a', 1) lett_num_type: <class 'tuple'>
lett_num: ('b', 2) lett_num_type: <class 'tuple'>
可以看出如果使用x,y依次取出zip(lett, num)中的数据,zip返回的数据类型不是元组,而仅使用lett_num取出zip的数据,zip返回的数据类型是tuple。zip使得不同的可迭代对象之间对应元素的运算更为简洁。
如果我写的内容对大家有帮助欢迎大家关注我,我的博客后续会持续更新,每一篇的内容如有补充也会进行精修,所涉及的代码资源也将在允许范围之内尽力提供给大家。我曾经受过很多人无私的帮助,所以我也想为知识的传播降低门槛,欢迎关注!
感恩所有帮助过我的人!