本文将从REINFORCE算法的缺点入手,引入Actor-Critic网络的改进,并给出A2C的细节和实现代码。
总的来说,Actor-Critic同时吸收了value-based和policy-based下的优点,不过本质上还是从policy-based开始改进的,这一系列算法的目标都是优化策略网络参数,只是引入了学习价值函数的网络帮助策略评估。
1.奖励函数的改进
在先前基于策略的REINFORCE算法中,存在因蒙特卡洛采样引起的高方差的问题,针对单个轨迹,我们对优化目标
J
(
π
θ
)
J(\pi_\theta)
J(πθ)的奖励回报有多种改进方式:
∇
θ
J
(
π
θ
)
=
E
τ
∼
P
(
τ
∣
θ
)
[
∑
t
=
0
T
∇
θ
l
o
g
π
θ
(
a
t
∣
s
t
)
R
(
τ
)
]
R
(
τ
)
=
∑
t
′
=
0
γ
t
′
r
t
′
(
轨迹的总回报
)
=
Q
π
(
s
0
,
a
0
)
\nabla _\theta J(\pi_\theta)=E_{\tau\thicksim P(\tau|\theta)}[\sum_{t=0}^T\nabla _{\theta}log\space\pi_\theta(a_t|s_t)R(\tau)] \\R(\tau)=\sum_{t'=0}\gamma^{t'}r_{t'}(轨迹的总回报)=Q^\pi(s_0,a_0)
∇θJ(πθ)=Eτ∼P(τ∣θ)[t=0∑T∇θlog πθ(at∣st)R(τ)]R(τ)=t′=0∑γt′rt′(轨迹的总回报)=Qπ(s0,a0)
- 截断回报:只记录动作 a t a_t at之后的回报, R ( τ ) = ∑ t ′ = t γ t ′ − t r t ′ = Q π ( s t , a t ) R(\tau)=\sum_{t'=t}\gamma^{t'-t}r_{t'}=Q^\pi(s_t,a_t) R(τ)=∑t′=tγt′−trt′=Qπ(st,at)
- 基线:在截断回报的基础上改进 R ( τ ) = ∑ t ′ = t γ t ′ − t r t ′ − b ( s t ) R(\tau)=\sum_{t'=t}\gamma^{t'-t}r_{t'}-b(s_t) R(τ)=∑t′=tγt′−trt′−b(st)
由于蒙特卡洛需要整个轨迹结束才能进行求解,所以我们考虑能否在轨迹过程中更新。对于截断回报,我们可以利用Q-learning中的思想,用另一个神经网络对Q进行估计,得 R ( τ ) = Q π ( s t , a t ) R(\tau)=Q^\pi(s_t,a_t) R(τ)=Qπ(st,at)。
在此基础上考虑基线方法,常用的baseline可以用状态价值函数 V ( s t ) V(s_t) V(st),此处如果我们将动作价值函数和状态价值函数相减,我们可以得到优势函数的形式 R ( τ ) = A π θ ( s t , a t ) R(\tau)=A^{\pi_\theta}(s_t,a_t) R(τ)=Aπθ(st,at),我们还可以利用Bellman方程展开 Q = r + γ V Q=r+\gamma V Q=r+γV得出 R ( τ t ) = r t + γ V π ( s t + 1 ) − V π ( s t ) R(\tau_t)=r_t+\gamma V^\pi(s_{t+1})-V^\pi(s_t) R(τt)=rt+γVπ(st+1)−Vπ(st),可以发现,原本的REINFORCE算法已经被我们改成了时序差分形式,同时借鉴了value-based下估计状态价值函数 V π V^\pi Vπ的思路。
二、Actor-Critic网络结构
由于现在存在一个估计状态价值函数的网络和一个更新策略梯度的网络,我们将这种网络结构称为Actor-Critic,将他们分别称为Actor(策略网络),Critic价值网络
-
Actor负责与环境交互,在Critic的指导下用策略梯度学习更好的策略(需要Critic的计算的V值)
-
Critic负责将Actor收集的样本用于计算状态价值函数,可用于判断当前状态的好坏
针对策略网络Actor,我们依然采用REINFORCE算法中的优化目标,值得注意的是,现在轨迹的每一步都可以进行更新了,为了保持稳定,轨迹中途的
R
(
τ
)
R(\tau)
R(τ)不记录梯度
∇
θ
J
(
π
θ
)
=
E
τ
∼
P
(
τ
∣
θ
)
[
∑
t
=
0
T
∇
θ
l
o
g
π
θ
(
a
t
∣
s
t
)
R
(
τ
)
]
R
(
τ
t
)
=
r
t
+
γ
V
π
(
s
t
+
1
)
−
V
π
(
s
t
)
\nabla _\theta J(\pi_\theta)=E_{\tau\thicksim P(\tau|\theta)}[\sum_{t=0}^T\nabla _{\theta}log\space\pi_\theta(a_t|s_t)R(\tau)]\\R(\tau_t)=r_t+\gamma V^\pi(s_{t+1})-V^\pi(s_t)
∇θJ(πθ)=Eτ∼P(τ∣θ)[t=0∑T∇θlog πθ(at∣st)R(τ)]R(τt)=rt+γVπ(st+1)−Vπ(st)
针对价值网络Critic,令价值网络表示为
V
ω
V_\omega
Vω,我们借鉴DQN中估计时序差分的动作状态价值函数的损失函数:
D
Q
N
:
L
o
s
s
=
E
[
(
(
Q
(
s
,
a
;
θ
)
−
(
r
+
γ
m
a
x
a
′
Q
(
s
′
,
a
′
;
θ
−
)
)
)
2
)
]
DQN:Loss=E[((Q(s,a;θ)−(r+γmax_{a′}Q(s′,a′;θ^-)))^2)]
DQN:Loss=E[((Q(s,a;θ)−(r+γmaxa′Q(s′,a′;θ−)))2)]
在价值函数中我们为了保持网络稳定,同样使用目标网络的思路,将
r
+
γ
V
ω
(
s
t
+
1
)
r+\gamma V_\omega(s_{t+1})
r+γVω(st+1)作为时序差分目标,使其不产生梯度更新价值函数,价值函数表示为:
C
r
i
t
i
c
:
L
(
ω
)
=
1
2
(
r
+
γ
V
ω
(
s
t
+
1
)
−
V
ω
(
s
t
)
)
2
Critic:L(\omega)=\frac{1}{2}(r+\gamma V_\omega(s_{t+1})-V_\omega(s_t))^2
Critic:L(ω)=21(r+γVω(st+1)−Vω(st))2
因此价值函数的梯度为
∇
C
r
i
t
i
c
:
∇
ω
L
(
ω
)
=
−
(
r
+
γ
V
ω
(
s
t
+
1
)
−
V
ω
(
s
t
)
)
∇
ω
V
ω
(
s
t
)
\nabla Critic:\nabla_{\omega}L(\omega)=-(r+\gamma V_\omega(s_{t+1})-V_\omega(s_t))\nabla_{\omega}V_{\omega}(s_t)
∇Critic:∇ωL(ω)=−(r+γVω(st+1)−Vω(st))∇ωVω(st)
三、优势Actor-Critic:A2C
在上文中,当我们将 R ( τ ) R(\tau) R(τ)将时序差分和优势函数的概念结合时,这就是Advantage Actor-Critic,简称为A2C
Actor-Crtic的形式多种多样,但A2C具体指的就是时序差分和优势函数结合的这种形式,上一节以A2C形式分析了Actor-Critic的细节
R ( τ t ) = r t + γ V π ( s t + 1 ) − V π ( s t ) R(\tau_t)=r_t+\gamma V^\pi(s_{t+1})-V^\pi(s_t) R(τt)=rt+γVπ(st+1)−Vπ(st)
四、异步优势Actor-Critic:A3C
先前的REINFORCE算法还有一个造成高方差的因素:轨迹随机性——在环境随机性较大的情况下,相近迭代的策略的轨迹也有非常不同的结果。我们既希望A2C能够有DQN中ReplyBuffer每次更新取多样本的方法,又由于policy-based下使用ReplayBuffer违背了on-policy的原理,我们考虑一种类似ReplyBuffer的操作——一次性计算多条轨迹,可利用多条轨迹的样本更新价值函数——既能降低样本相关性,减小方差,又能满足on-policy同分布的原则。这就是A3C(Asynchronous A2C,异步优势A-C)
异步更新使用相近轨迹采样,依然没有达到样本间独立的程度,只是减少了样本相关性
A3C利用了多线程的实现多个agent,A3C定义了多个相同的agent网络,每个agent均使用A2C结构,但是他们的分工不同。存在一个global agent和多个worker agent,worker agent使用当前策略与环境进行交互,将轨迹信息上传给global agent。global agent负责将所有worker采集的轨迹信息汇总,再将其发送给每一个worker agent。A3C相比于A2C没有算法上的革新,难的主要是怎么写一个多线程来跑多个agent,下面直接上代码,我们从代码层面讲解 global 和 worker具体做了什么。
五、A2C代码实现
我们先对A2C开刀,理解下 Actor 和 Critic 网络之间的关系
- 策略网络Actor部分,类似REINFORCE中更新策略
class Actor(nn.Module):
def __init__(self,state_dim,hidden_dim,action_dim):
super(Actor,self).__init__()
self.fc1 = nn.Linear(state_dim,hidden_dim)
self.fc2 = nn.Linear(hidden_dim,action_dim)
def forward(self,x):
x = F.relu(self.fc1(x))
return F.softmax(self.fc2(x),dim=1) #区别在此
- 价值网络Crtic,类似DQN中评估价值
class Crtic(nn.Module):
def __init__(self,state_dim,hidden_dim):
super(Actor,self).__init__()
self.fc1 = nn.Linear(state_dim,hidden_dim)
self.fc2 = nn.Linear(hidden_dim)
def forward(self,x):
x = F.relu(self.fc1(x))
return self.fc2(x)
区别1:输出不同,Actor输出策略-要求不同动作概率总和为1(用softmax约束),Critic输出值
区别2:action_dim不同,Crtic输出一个值,所以action_dim=1
Actor-Critic网络:
class A2C:
def __init__(self,state_dim,hidden_dim,action_dim,actor_lr,critic_lr,gamma,device):
self.actor = Actor(state_dim,hidden_dim,action_dim).to(device)
self.critic = Crtic(state_dim,hidden_dim).to(device)
self.optimizer_actor = torch.optim.Adam(self.actor.parameters(),lr=actor_lr)#对Actor进行优化
self.optimizer_critic = torch.optim.Adam(self.critic.parameters(),lr=critic_lr)#对Critic进行优化
self.gamma = gamma
self.device = device
#Actor进行决策,仿照reinforce
def take_action(self,state):
state = torch.tensor([state],dtype=torch.float).to(self.device)
action_probs = self.actor(state)
action_dist = torch.distributions.Categorical(action_probs)
action = action_dist.sample()
return action.item()
#对Actor和Critic进行更新
def update(self,transition_dict):
rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1,1).to(self.device)
states = torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)
actions = torch.tensor(transition_dict['actions']).view(-1,1).to(self.device)
dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1,1).to(self.device)
next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)
#计算时序差分目标
v_value = self.critic(states)
next_v_value = self.critic(next_states)
td_target = rewards + self.gamma * next_v_value * (1 - dones)
td_delta = td_target - v_value
#对actor进行优化
log_probs = torch.log(self.actor(states).gather(1,actions))
actor_loss = torch.mean(-log_probs * td_delta.detach())#保持轨迹内的稳定性,不记录td-delta的梯度
#对critic进行优化
critic_loss = torch.mean(F.mse_loss(v_value,td_target.detach()))#价值函数中,我们使得目标函数稳定,所以不对td_target传播梯度
self.optimizer_actor.zero_grad()
self.optimizer_critic.zero_grad()
actor_loss.backward()
critic_loss.backward()
self.optimizer_critic.step()
self.optimizer_actor.step()
下面截取main的主要部分:
for i in range(10):
with tqdm(total=int(args.num_episodes/10),desc="Iteration %d" % i ) as pbar:
for i_episode in range(int(args.num_episodes/10)):
episode_return = 0
state = env.reset()
state = state[0]
done = False
transition_dict = {
"states":[],
"actions":[],
"rewards":[],
"next_states":[],
"dones":[],
}
while not done:
action = agent.take_action(state)
next_state,reward,done, _ , _ = env.step(action)
transition_dict["states"].append(state)
transition_dict["actions"].append(action)
transition_dict["rewards"].append(reward)
transition_dict["next_states"].append(next_state)
transition_dict["dones"].append(done)
state = next_state
episode_return += reward
return_list.append(episode_return)
agent.update(transition_dict)
更新策略是,一次更新使用一条轨迹,一轮有多次更新,单次轨迹内使用时序差分更新