SAC (Soft Actor Critic) Learning Notes

Introduction

The SAC (Soft Actor Critic) algorithm has drawn a lot of attention in recent years and is well regarded by many deep reinforcement learning researchers. This post covers the theoretical analysis of SAC and its core code implementation.

Unlike many deep reinforcement learning algorithms whose goal is to maximize the cumulative reward, SAC maximizes the entropy-regularized cumulative reward, which encourages the agent to explore more and usually leads to better training results.
$$\max_{\pi_{\theta}}\left[\sum_{t}\gamma^{t}\left(r(\mathbf{s}_{t},\mathbf{a}_{t})+\alpha\mathcal{H}\left(\pi_{\theta}(\cdot\mid\mathbf{s}_{t})\right)\right)\right]$$
That is, SAC seeks a stochastic policy
$$\pi^{\ast}=\arg\max_{\pi_{\theta}}\sum_{t}\mathbb{E}_{(\mathbf{s}_{t},\mathbf{a}_{t})\sim\rho_{\pi_{\theta}}}\left[r(\mathbf{s}_{t},\mathbf{a}_{t})+\alpha\mathcal{H}\left(\pi_{\theta}(\cdot\mid\mathbf{s}_{t})\right)\right]$$
In general, V and Q are related by
$$\hat{V}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t})\equiv\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[\hat{Q}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t},\mathbf{a}_{t})\right]$$
In SAC, we instead use the soft (entropy-augmented) value function:
$$\begin{aligned}\hat{V}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t})&=\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[\hat{Q}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t},\mathbf{a}_{t})\right]+\alpha\mathcal{H}\left(\pi_{\theta}(\cdot\mid\mathbf{s}_{t})\right)\\&=\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[\hat{Q}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t},\mathbf{a}_{t})\right]+\alpha\,\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[-\log\pi_{\theta}(\mathbf{a}_{t}\mid\mathbf{s}_{t})\right]\\&=\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[\hat{Q}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t},\mathbf{a}_{t})-\alpha\log\pi_{\theta}(\mathbf{a}_{t}\mid\mathbf{s}_{t})\right]\end{aligned}$$
SAC comes in two versions. The first version uses a Q network, a V network, and a policy network, with a fixed entropy-regularization coefficient. The second version drops the V network, uses double Q networks, and adds a method for automatically tuning the entropy coefficient. The first version is described below, followed by the second.

SAC (Version 1)

Objective of the V network
$$J_{V}(\psi)=\mathbb{E}_{\mathbf{s}_{t}\sim\mathcal{D}}\left[\frac{1}{2}\left(V_{\psi}(\mathbf{s}_{t})-\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\phi}}\left[Q_{\theta}(\mathbf{s}_{t},\mathbf{a}_{t})-\log\pi_{\phi}(\mathbf{a}_{t}\mid\mathbf{s}_{t})\right]\right)^{2}\right]$$
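As a rough illustration (not code from this post), this loss could be computed as follows, assuming hypothetical value_network and q_network modules, a policy_network whose evaluate() returns an action and its log-probability (as in the Actor class shown later), and $\alpha$ fixed to 1:

import torch.nn.functional as F

# state: a batch of states sampled from the replay buffer D
new_action, log_prob = policy_network.evaluate(state)           # a ~ pi_phi(.|s), log pi_phi(a|s)
v_target = q_network(state, new_action) - log_prob              # Q_theta(s, a) - log pi_phi(a|s)
v_loss = F.mse_loss(value_network(state), v_target.detach())    # squared error, up to the 1/2 factor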
Objective of the Q network
$$J_{Q}(\theta)=\mathbb{E}_{(\mathbf{s}_{t},\mathbf{a}_{t})\sim\mathcal{D}}\left[\frac{1}{2}\left(Q_{\theta}(\mathbf{s}_{t},\mathbf{a}_{t})-\hat{Q}(\mathbf{s}_{t},\mathbf{a}_{t})\right)^{2}\right]$$
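Here $\hat{Q}$ is the bootstrapped target, which in the original paper is computed with a target value network $V_{\bar{\psi}}$ (an exponential moving average of $V_{\psi}$):

$$\hat{Q}(\mathbf{s}_{t},\mathbf{a}_{t})=r(\mathbf{s}_{t},\mathbf{a}_{t})+\gamma\,\mathbb{E}_{\mathbf{s}_{t+1}\sim p}\left[V_{\bar{\psi}}(\mathbf{s}_{t+1})\right]$$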
Objective of the policy network
$$J_{\pi}(\phi)=\mathbb{E}_{\mathbf{s}_{t}\sim\mathcal{D}}\left[D_{\mathrm{KL}}\left(\pi_{\phi}(\cdot\mid\mathbf{s}_{t})\,\Big\|\,\frac{\exp\left(Q_{\theta}(\mathbf{s}_{t},\cdot)\right)}{Z_{\theta}(\mathbf{s}_{t})}\right)\right]$$
At first glance the policy objective may look unintuitive. In fact, $\frac{\exp\left(Q_{\theta}(\mathbf{s}_{t},\cdot)\right)}{Z_{\theta}(\mathbf{s}_{t})}$ is the solution of the problem below, where $Z_{\theta}(\mathbf{s}_{t})$ is a normalizing constant, $Z(s)=\sum_{a}\exp\left(\frac{1}{\alpha}Q(s,a)\right)$:
$$\pi^{\ast}=\arg\max_{\pi_{\theta}}\sum_{t}\mathbb{E}_{(\mathbf{s}_{t},\mathbf{a}_{t})\sim\rho_{\pi_{\theta}}}\left[r(\mathbf{s}_{t},\mathbf{a}_{t})+\alpha\mathcal{H}\left(\pi_{\theta}(\cdot\mid\mathbf{s}_{t})\right)\right]$$
If the policy model we use cannot represent the optimal policy $\pi$ exactly, we can instead minimize the KL divergence between the two.
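Expanding the KL divergence and dropping $\log Z_{\theta}(\mathbf{s}_{t})$, which does not depend on $\phi$, gives the equivalent objective below; with the re-parameterization $\mathbf{a}_{t}=f_{\phi}(\epsilon_{t};\mathbf{s}_{t})$ this is exactly what the actor loss in the update code later in this post implements (with $\alpha=1$):

$$J_{\pi}(\phi)=\mathbb{E}_{\mathbf{s}_{t}\sim\mathcal{D},\,\epsilon_{t}\sim\mathcal{N}}\left[\alpha\log\pi_{\phi}\left(f_{\phi}(\epsilon_{t};\mathbf{s}_{t})\mid\mathbf{s}_{t}\right)-Q_{\theta}\left(\mathbf{s}_{t},f_{\phi}(\epsilon_{t};\mathbf{s}_{t})\right)\right]$$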

SAC (Version 2)

Version 1 of SAC uses three networks. However, the V network and the Q network are inherently related, so the second version of SAC removes the V network and uses double Q networks to mitigate overestimation. It also provides a way to adjust $\alpha$ automatically. In general, the second version is the recommended one. Version 2 is similar to version 1 in most respects, so this post focuses on the differences.

The entropy coefficient can be tuned automatically by minimizing the loss function below, where the target entropy is set to $\kappa=-\mathrm{dim}(\mathcal{A})$:
$$J(\alpha)=\mathbb{E}_{\mathbf{a}\sim\pi_{\theta}}\left[-\alpha\log\pi_{\theta}(\mathbf{a}\mid\mathbf{s})-\alpha\kappa\right]$$
Interested readers can refer to the SAC papers for the detailed derivation.

Re-parameterization

Re-parameterization lowers the variance of the expectation estimate and lets gradients back-propagate through the sampling step; SAC relies on this trick. Suppose we already have the action mean and standard deviation $\mu_{\theta}$ and $\sigma_{\theta}$; we then set
$$\mathbf{a}_{t}=\tanh\left(\mu_{\theta}+\epsilon\cdot\sigma_{\theta}\right),\quad\epsilon\sim\mathcal{N}(0,1)$$
The corresponding Python code is:

from torch.distributions import Normal
normal = Normal(mean, std)
z = normal.rsample()

In PyTorch, Normal provides both sample and rsample. sample draws directly from the defined distribution, whereas rsample first samples from the standard normal distribution N(0, 1) and then returns mean + std × sample, so the result remains differentiable with respect to mean and std; for re-parameterization, use rsample. In my own experience, the agent failed to learn a useful policy with sample, but training converged quickly after switching to rsample.
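The small snippet below (illustrative only) makes the difference visible: the output of rsample carries a gradient history back to mean and std, while the output of sample does not:

import torch
from torch.distributions import Normal

mean = torch.zeros(3, requires_grad=True)
std = torch.ones(3, requires_grad=True)
normal = Normal(mean, std)

print(normal.sample().requires_grad)    # False: plain sampling, the gradient path is cut
print(normal.rsample().requires_grad)   # True: computed as mean + std * eps with eps ~ N(0, 1)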

Code Examples

Policy network

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

# assumed to be defined once at module level
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action=1, init_w=3e-3):
        super(Actor, self).__init__()

        self.l1 = nn.Linear(state_dim, 128)
        self.l2 = nn.Linear(128, 128)
        self.l3_mean = nn.Linear(128, action_dim)
        self.log_std_linear = nn.Linear(128, action_dim)
        self.max_action = max_action

        self.l3_mean.weight.data.uniform_(-init_w, init_w)
        self.l3_mean.bias.data.uniform_(-init_w, init_w)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        mean = self.l3_mean(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, -20, 2)

        return mean, log_std

    def evaluate(self, state, epsilon=1e-6):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.rsample()
        action = torch.tanh(z)                     # squash the Gaussian sample into (-1, 1)
        # change-of-variables correction for the tanh squashing (see the formula below)
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(1, keepdim=True)

        return action, log_prob

    def select_action(self, state):
        state = torch.FloatTensor(state).to(device)
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.rsample()
        action = torch.tanh(z)
        action = action.detach().cpu().numpy()
        return action

Note that during training I normalized the agent's action range in the environment, so max_action = 1.

 log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)

log_prob in the code corresponds to $\log\pi(\mathbf{a}\mid\mathbf{s})$. This line is based on the following equation from the original paper; epsilon is added so that the argument of the log in the second term never reaches zero (which would give negative infinity).

$$\log\pi(\mathbf{a}\mid\mathbf{s})=\log\mu(\mathbf{u}\mid\mathbf{s})-\sum_{i=1}^{D}\log\left(1-\tanh^{2}(u_{i})\right)$$
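As a side note, some implementations compute this correction in a numerically more stable way using the identity $\log\left(1-\tanh^{2}(u)\right)=2\left(\log 2-u-\operatorname{softplus}(-2u)\right)$, which removes the need for epsilon. The sketch below (not the code used in this post; the helper name is made up) shows that variant:

import math
import torch.nn.functional as F

def tanh_log_prob_correction(z):
    # log(1 - tanh(z)^2), rewritten to avoid evaluating log near zero
    return 2.0 * (math.log(2.0) - z - F.softplus(-2.0 * z))

# log_prob = normal.log_prob(z) - tanh_log_prob_correction(z)
# would then replace the epsilon-based line above (summed over action dimensions as before)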

Q network

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, init_w=3e-3):
        super(Critic, self).__init__()

        self.l1 = nn.Linear(state_dim + action_dim, 128)
        self.l2 = nn.Linear(128, 128)
        self.l3 = nn.Linear(128, 1)

        self.l3.weight.data.uniform_(-init_w, init_w)
        self.l3.bias.data.uniform_(-init_w, init_w)

    def forward(self, x, u):
        x = F.relu(self.l1(torch.cat([x, u], 1)))
        x = F.relu(self.l2(x))
        x = self.l3(x)
        return x

This part is not much different from the Q-network definitions used in other algorithms.
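For reference, the two critics and their target networks used in the update code below could be set up roughly as follows (a sketch; state_dim, action_dim, device, and the learning rate 3e-4 are assumptions, not values taken from this post):

import torch

critic_1 = Critic(state_dim, action_dim).to(device)
critic_2 = Critic(state_dim, action_dim).to(device)
critic_target_1 = Critic(state_dim, action_dim).to(device)
critic_target_2 = Critic(state_dim, action_dim).to(device)

# the target networks start as exact copies of the online critics
critic_target_1.load_state_dict(critic_1.state_dict())
critic_target_2.load_state_dict(critic_2.state_dict())

critic_optimizer_1 = torch.optim.Adam(critic_1.parameters(), lr=3e-4)
critic_optimizer_2 = torch.optim.Adam(critic_2.parameters(), lr=3e-4)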

Update parameters

def update(self):

        # Sample replay buffer
        state, action, reward, next_state, done = self.replay_buffer.sample(args.batch_size)
        state = torch.FloatTensor(state).to(device)
        action = torch.FloatTensor(action).to(device)
        reward = torch.FloatTensor(reward).to(device)
        next_state = torch.FloatTensor(next_state).to(device)
        done = torch.FloatTensor(1 - done).to(device)  # mask: 1.0 for non-terminal transitions, 0.0 for terminal ones

        next_action, next_log_prob = self.policy_network.evaluate(next_state)
        # Compute the target Q value
        target_Q_1 = self.critic_target_1(next_state, next_action)
        target_Q_2 = self.critic_target_2(next_state, next_action)
        # soft Q target: min of the two target critics minus the log-probability (alpha is implicitly 1)
        target_Q = torch.min(target_Q_1, target_Q_2) - next_log_prob
        # reshape to (batch_size, 1) instead of hard-coding the batch size
        my_target_Q = reward.reshape(-1, 1) + (done.reshape(-1, 1) * args.gamma * target_Q)

        # Get current Q estimate
        current_Q_1 = self.critic_1(state, action)
        current_Q_2 = self.critic_2(state, action)

        # Compute critic loss
        critic_loss_1 = F.mse_loss(current_Q_1, my_target_Q.detach())
        critic_loss_2 = F.mse_loss(current_Q_2, my_target_Q.detach())
        critic_loss = critic_loss_1 + critic_loss_2
        
        # Optimize the critic
        self.critic_optimizer_1.zero_grad()
        self.critic_optimizer_2.zero_grad()
        critic_loss.backward()
        self.critic_optimizer_1.step()
        self.critic_optimizer_2.step()

        # update the actor and the target networks every other critic update
        if self.update_step % 2 == 0:
            new_action, log_prob = self.policy_network.evaluate(state)
            # Compute actor loss
            min_q = torch.min(
                self.critic_1(state, new_action),
                self.critic_2(state, new_action)
            )
            actor_loss = (log_prob - min_q).mean()
            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(self.critic_1.parameters(), self.critic_target_1.parameters()):
                target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
            for param, target_param in zip(self.critic_2.parameters(), self.critic_target_2.parameters()):
                target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
        self.update_step += 1

The parameter update consists of three parts: the Q networks, the policy network, and $\alpha$. The code above does not implement the third part; readers who want automatic tuning of $\alpha$ only need to write the corresponding code from the loss function given earlier, for example as in the sketch below.
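A minimal sketch of that update, assuming hypothetical attributes self.log_alpha, self.alpha_optimizer, and self.target_entropy (optimizing $\log\alpha$ instead of $\alpha$ keeps the coefficient positive), could look like this:

# in __init__ (sketch):
#   self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
#   self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=3e-4)
#   self.target_entropy = -action_dim                  # kappa = -dim(A)

# inside update(), after computing log_prob for the current states:
alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
self.alpha_optimizer.zero_grad()
alpha_loss.backward()
self.alpha_optimizer.step()
alpha = self.log_alpha.exp().detach()
# alpha would then multiply the log-probability terms in the critic target and the actor loss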

References

1. Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor
2. Soft Actor-Critic Algorithms and Applications
3. Deep Reinforcement Learning: Fundamentals, Research and Applications
4. From Policy Gradient to Actor-Critic methods: Soft Actor Critic, ISIR
5. https://github.com/cyoon1729/Policy-Gradient-Methods
