[Reinforcement Learning Notes] (5) SAC

1. SAC

Soft Actor-Critic (SAC) is an off-policy algorithm that combines stochastic policy optimization with DDPG-style approaches, and it can be applied to both continuous and discrete action spaces.

This post's treatment of SAC is based on OpenAI's Soft Actor-Critic — Spinning Up documentation (openai.com).

Related posts in this series:

① [Reinforcement Learning Notes] (3) DDPG - CSDN Blog

② [Reinforcement Learning Notes] (4) TD3 - CSDN Blog


1.1 Basic Principles

The key ingredient of SAC is entropy regularization: the policy is trained to trade off expected return against entropy, which closely resembles the exploration-exploitation trade-off. Increasing entropy encourages exploration and prevents the policy from converging prematurely to a poor local optimum.

1.1.1 entropy-regularized RL

To explain what Soft Actor-Critic is, we first need the concept of entropy-regularized RL.

Entropy measures how random a variable is. For example, if a coin is weighted so that it almost always lands heads, it is not very random and its entropy is low; if it is weighted so that heads and tails each come up half the time, its entropy is high.

Suppose x is a random variable with probability density function P. The entropy H of x is defined as

H(P)=\underset{x \sim P}{\mathrm{E}}[-\log P(x)]
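
To make the coin example concrete, here is a small standalone check (not part of the reference code) that evaluates H(P) for a heavily biased coin and a fair coin:

import numpy as np

def entropy(p):
    """Shannon entropy H(P) = E[-log P(x)] of a discrete distribution p."""
    p = np.asarray(p, dtype=float)
    p = p[p > 0]                     # treat 0 * log(0) as 0
    return float(-(p * np.log(p)).sum())

print(entropy([0.99, 0.01]))   # heavily biased coin -> low entropy, about 0.056 nats
print(entropy([0.5, 0.5]))     # fair coin -> high entropy, log(2) ~ 0.693 nats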

In entropy-regularized RL, the agent receives a bonus reward at each timestep proportional to the entropy of the policy at that timestep.

This changes the definition of V^{\pi} to include the entropy bonus from every timestep:

V^\pi(s)=\underset{\tau \sim \pi}{\mathrm{E}}\left[\sum_{t=0}^{\infty} \gamma^t\left(R\left(s_t, a_t, s_{t+1}\right)+\alpha H\left(\pi\left(\cdot \mid s_t\right)\right)\right) \mid s_0=s\right]

Q^{\pi} changes as well: it includes the entropy bonuses from every timestep except the first:

Q^\pi(s, a)=\underset{\tau \sim \pi}{\mathrm{E}}\left[\sum_{t=0}^{\infty} \gamma^t R\left(s_t, a_t, s_{t+1}\right)+\alpha \sum_{t=1}^{\infty} \gamma^t H\left(\pi\left(\cdot \mid s_t\right)\right) \mid s_0=s, a_0=a\right]

From these definitions, the relation between V^{\pi} and Q^{\pi} is

V^\pi(s)=\underset{a \sim \pi}{\mathrm{E}}\left[Q^\pi(s, a)\right]+\alpha H(\pi(\cdot \mid s))

and the Bellman equation for Q^{\pi} can be written as

\begin{aligned} Q^\pi(s, a) & =\underset{\substack{s^{\prime} \sim P \\ a^{\prime} \sim \pi}}{\mathrm{E}}\left[R\left(s, a, s^{\prime}\right)+\gamma\left(Q^\pi\left(s^{\prime}, a^{\prime}\right)+\alpha H\left(\pi\left(\cdot \mid s^{\prime}\right)\right)\right)\right] \\ & =\underset{s^{\prime} \sim P}{\mathrm{E}}\left[R\left(s, a, s^{\prime}\right)+\gamma V^\pi\left(s^{\prime}\right)\right] . \end{aligned}

1.1.2 Learning Q

SAC uses five networks in total: one policy network and four Q-networks (two online Q-functions plus their two target copies). The Q-function setup is very similar to TD3's, but with some differences.

The basic idea of learning the Q-networks is the same as in the earlier algorithms: two of the four Q-networks are target networks, and the goal is to push the Q-value predicted for each state-action pair as close as possible to the target Q-value by minimizing the mean-squared Bellman error (MSBE) loss.

The target Q-networks are updated with soft (polyak) updates. The Spinning Up documentation summarizes the similarities and differences with TD3 as follows:

Similarity:

  • Like in TD3, both Q-functions are learned with MSBE minimization, by regressing to a single shared target.
  • Like in TD3, the shared target is computed using target Q-networks, and the target Q-networks are obtained by polyak averaging the Q-network parameters over the course of training.
  • Like in TD3, the shared target makes use of the clipped double-Q trick.

Difference:

  • Unlike in TD3, the target also includes a term that comes from SAC’s use of entropy regularization.
  • Unlike in TD3, the next-state actions used in the target come from the current policy instead of a target policy.
  • Unlike in TD3, there is no explicit target policy smoothing. TD3 trains a deterministic policy, and so it accomplishes smoothing by adding random noise to the next-state actions. SAC trains a stochastic policy, and so the noise from that stochasticity is sufficient to get a similar effect.

In short, the similarities: both Q-functions are regressed toward a single shared target, which is computed with the clipped double-Q trick (the smaller of the two target Q-networks' outputs is used to compute the target Q-value).

And the differences: the target additionally contains an entropy-regularization term; there is only one policy network (no target policy); and because SAC's policy is stochastic rather than deterministic, no extra random noise needs to be added to its outputs.

When computing the Q-target, because there is only one policy network (unlike in TD3), the next action is sampled from the current policy, and Q^{\pi} can be approximated as

Q^\pi(s, a) \approx r+\gamma\left(Q^\pi\left(s^{\prime}, \tilde{a}^{\prime}\right)-\alpha \log \pi\left(\tilde{a}^{\prime} \mid s^{\prime}\right)\right), \quad \tilde{a}^{\prime} \sim \pi\left(\cdot \mid s^{\prime}\right)

The Q-functions are updated by minimizing the MSBE loss. The target is computed as in TD3 with the clipped double-Q trick: the smaller of the two target Q-functions' outputs is used,

y\left(r, s^{\prime}, d\right)=r+\gamma(1-d)\left(\min _{j=1,2} Q_{\phi_{\text {targ }, j}}\left(s^{\prime}, \tilde{a}^{\prime}\right)-\alpha \log \pi_\theta\left(\tilde{a}^{\prime} \mid s^{\prime}\right)\right), \quad \tilde{a}^{\prime} \sim \pi_\theta\left(\cdot \mid s^{\prime}\right)

and the resulting loss function for each Q-network is

L\left(\phi_i, \mathcal{D}\right)=\underset{\left(s, a, r, s^{\prime}, d\right) \sim \mathcal{D}}{\mathrm{E}}\left[\left(Q_{\phi_i}(s, a)-y\left(r, s^{\prime}, d\right)\right)^2\right]
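
For reference, the following is a minimal PyTorch-style sketch (not part of the reference code, which is the discrete-action variant shown in section 1.2) of how the target y(r, s', d) and the MSBE loss could be computed for continuous actions; the names q1, q2, q1_targ, q2_targ, policy and the batch layout are assumptions made for this illustration:

import torch
import torch.nn.functional as F

def sac_critic_loss(q1, q2, q1_targ, q2_targ, policy, batch, gamma, alpha):
    # batch holds tensors (s, a, r, s2, d); policy(s) is assumed to return (action, log_prob)
    s, a, r, s2, d = batch
    with torch.no_grad():
        a2, logp_a2 = policy(s2)                              # a' ~ pi_theta(.|s')
        q_targ = torch.min(q1_targ(s2, a2), q2_targ(s2, a2))  # clipped double-Q
        y = r + gamma * (1 - d) * (q_targ - alpha * logp_a2)  # entropy-regularized target y(r, s', d)
    loss_q1 = F.mse_loss(q1(s, a), y)                         # L(phi_1, D)
    loss_q2 = F.mse_loss(q2(s, a), y)                         # L(phi_2, D)
    return loss_q1, loss_q2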

1.1.3 Learning the Policy

The policy should maximize the expected future return while also maintaining high expected entropy; in other words, it should maximize V^{\pi} in each state:

\begin{aligned} V^\pi(s) & =\underset{a \sim \pi}{\mathrm{E}}\left[Q^\pi(s, a)\right]+\alpha H(\pi(\cdot \mid s)) \\ & =\underset{a \sim \pi}{\mathrm{E}}\left[Q^\pi(s, a)-\alpha \log \pi(a \mid s)\right] \end{aligned}

Actions are sampled with the reparameterization trick, using a tanh-squashed Gaussian policy:

\tilde{a}_\theta(s, \xi)=\tanh \left(\mu_\theta(s)+\sigma_\theta(s) \odot \xi\right), \quad \xi \sim \mathcal{N}(0, I)

The Q-value of the sampled action is estimated with the smaller of the two Q-functions' outputs. Rewriting the expectation over actions as an expectation over the noise variable gives

\underset{a \sim \pi_\theta}{\mathrm{E}}\left[Q^{\pi_\theta}(s, a)-\alpha \log \pi_\theta(a \mid s)\right]=\underset{\xi \sim \mathcal{N}}{\mathrm{E}}\left[Q^{\pi_\theta}\left(s, \tilde{a}_\theta(s, \xi)\right)-\alpha \log \pi_\theta\left(\tilde{a}_\theta(s, \xi) \mid s\right)\right]

so the final policy optimization objective, taken over states sampled from the replay buffer, is

\max _\theta \underset{\substack{s \sim \mathcal{D} \\ \xi \sim \mathcal{N}}}{\mathrm{E}}\left[\min _{j=1,2} Q_{\phi_j}\left(s, \tilde{a}_\theta(s, \xi)\right)-\alpha \log \pi_\theta\left(\tilde{a}_\theta(s, \xi) \mid s\right)\right]
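
A minimal sketch of such a reparameterized, tanh-squashed Gaussian policy head (an assumption for illustration; the code in section 1.2 uses a discrete categorical policy instead, and the tanh log-probability correction shown here follows the standard SAC derivation rather than the formulas above):

import torch
import torch.nn as nn

class SquashedGaussianHead(nn.Module):
    # hypothetical continuous-action head: a = tanh(mu + sigma * xi), xi ~ N(0, I)
    def __init__(self, n_hidden, n_actions):
        super().__init__()
        self.mu = nn.Linear(n_hidden, n_actions)
        self.log_sigma = nn.Linear(n_hidden, n_actions)

    def forward(self, h):
        mu = self.mu(h)
        log_sigma = self.log_sigma(h).clamp(-20, 2)            # keep sigma in a sane range
        dist = torch.distributions.Normal(mu, log_sigma.exp())
        u = dist.rsample()                                     # reparameterized sample mu + sigma * xi
        a = torch.tanh(u)                                      # squash the action into (-1, 1)
        # log pi(a|s): Gaussian log-prob plus the change-of-variables term for tanh
        logp = dist.log_prob(u).sum(-1) - torch.log(1 - a.pow(2) + 1e-6).sum(-1)
        return a, logp

The policy loss is then the negative of the objective above, evaluated with the smaller of the two Q-functions' outputs for the sampled action.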

1.2 Code Walkthrough

Note: the code comes from reference link [2].

1.2.1 Building the Networks

# imports used by the network definitions
from torch import nn
from torch.nn import functional as F


# policy net: outputs a probability distribution over the discrete actions
class PolicyNet(nn.Module):
    def __init__(self, n_states, n_actions, n_hiddens):
        super(PolicyNet, self).__init__()
        layers = []
        layer_shape = [n_states] + list(n_hiddens) + [n_actions]
        activation = nn.ReLU
        for i in range(len(layer_shape) - 1):
            layers += [nn.Linear(layer_shape[i], layer_shape[i+1])]
            # ReLU on hidden layers only; the output layer stays linear
            if i < len(layer_shape) - 2:
                layers += [activation()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = self.net(x)
        # softmax turns the logits into action probabilities
        return F.softmax(x, dim=-1)


# critic net: outputs one Q-value per discrete action
class CriticNet(nn.Module):
    def __init__(self, n_states, n_actions, n_hiddens):
        super(CriticNet, self).__init__()

        layers = []
        layer_shape = [n_states] + list(n_hiddens) + [n_actions]
        activation = nn.ReLU
        for i in range(len(layer_shape) - 1):
            layers += [nn.Linear(layer_shape[i], layer_shape[i+1])]
            # no activation on the output layer: Q-values are unbounded
            if i < len(layer_shape) - 2:
                layers += [activation()]

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = self.net(x)
        return x

Inside SAC_dis.__init__ (shown in full in section 1.4), the five networks are created and the target critics are initialized from the online critics:

        # policy net
        self.actor = PolicyNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
        # Q1, Q2, target Q1, target Q2
        self.critic_1 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
        self.critic_2 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
        self.target_critic_1 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
        self.target_critic_2 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)

        # initialize target Q net
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())

1.2.2 Updating the Q-Functions

The TD target is computed from the smaller of the two target Q-values (weighted by the next-state action probabilities), plus the entropy bonus:

    def calculate_target(self, r, s_, dones):
        """
        calculate the TD target y(r, s', d) used to update both Q-functions
        :param r: rewards
        :param s_: next_states
        :param dones: episode-termination flags
        :return: td_target = r + gamma * (1 - d) * (E_a'[min Q_targ(s', a')] + alpha * H(pi(.|s')))
        """
        next_probs = self.actor(s_)                      # pi(.|s'), shape [batch, n_actions]
        next_log_probs = torch.log(next_probs + 1e-8)
        # entropy of the next-state policy: H(pi(.|s')) = -sum_a pi(a|s') * log pi(a|s')
        entropy = -torch.sum(next_probs * next_log_probs, dim=1, keepdim=True)
        # clipped double-Q: expectation over actions of the element-wise smaller target Q-value
        tar_q1_value = self.target_critic_1(s_)
        tar_q2_value = self.target_critic_2(s_)
        min_tar_q = torch.sum(next_probs * torch.min(tar_q1_value, tar_q2_value), dim=1, keepdim=True)

        # entropy-regularized target: the entropy bonus is added, weighted by alpha
        td_target = r + self.gamma * (1 - dones) * (min_tar_q + self.log_alpha.exp() * entropy)
        return td_target
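
Because the action space here is discrete, the expectation over the next action can be computed exactly from the policy's action probabilities instead of sampling a single a', so calculate_target implements

y\left(r, s^{\prime}, d\right)=r+\gamma(1-d)\left(\sum_{a^{\prime}} \pi_\theta\left(a^{\prime} \mid s^{\prime}\right) \min _{j=1,2} Q_{\phi_{\text {targ }, j}}\left(s^{\prime}, a^{\prime}\right)+\alpha H\left(\pi_\theta\left(\cdot \mid s^{\prime}\right)\right)\right)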

Given the TD target, Q1 and Q2 each compute their own TD error and MSE loss, and the two Q-functions are updated separately.

        # --------------------------------- #
        # update 2 critic networks, compute loss
        # --------------------------------- #
        critic_1_qvalue = self.critic_1(states).gather(1, actions)
        critic_1_loss = torch.mean(F.mse_loss(critic_1_qvalue, td_target.detach()))
        critic_2_qvalue = self.critic_2(states).gather(1, actions)
        critic_2_loss = torch.mean(F.mse_loss(critic_2_qvalue, td_target.detach()))

        # optimize c1
        self.optimizer_c1.zero_grad()
        critic_1_loss.backward()
        self.optimizer_c1.step()
        # optimize c2
        self.optimizer_c2.zero_grad()
        critic_2_loss.backward()
        self.optimizer_c2.step()

The target Q-functions are then updated with a soft update:

        # soft update target Q1 & Q2
        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)
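
Each target parameter is moved a small step tau (self.tau) toward its online counterpart:

\phi_{\text {targ }} \leftarrow(1-\tau) \phi_{\text {targ }}+\tau \phi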

1.2.3 Updating the Policy

        # --------------------------------- #
        # update policy network, compute loss
        # --------------------------------- #
        probs = self.actor(states)
        log_probs = torch.log(probs + 1e-8)
        entropy = -torch.sum(probs * log_probs, dim=1, keepdim=True)

        # q-functions predict the current q-value
        q1_value = self.critic_1(states)
        q2_value = self.critic_2(states)
        min_qvalue = torch.sum(probs * torch.min(q1_value, q2_value), dim=1, keepdim=True)

        # loss
        actor_loss = torch.mean(-self.log_alpha.exp()*entropy - min_qvalue)
        # optimize actor
        self.optimizer_a.zero_grad()
        actor_loss.backward()
        self.optimizer_a.step()
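
Written as an exact expectation over the discrete actions, the actor loss minimized above is

L(\theta)=\underset{s \sim \mathcal{D}}{\mathrm{E}}\left[-\alpha H\left(\pi_\theta(\cdot \mid s)\right)-\sum_a \pi_\theta(a \mid s) \min _{j=1,2} Q_{\phi_j}(s, a)\right]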

1.3 Pseudocode
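
A brief sketch of the overall SAC training loop, following the Spinning Up description in reference [1] (the discrete-action code in section 1.4 follows the same structure):

    initialize policy pi_theta, Q-functions Q_phi1, Q_phi2, and an empty replay buffer D
    set target parameters: phi_targ1 <- phi1, phi_targ2 <- phi2
    repeat:
        observe state s, sample a ~ pi_theta(.|s), execute a, observe r, s', d
        store (s, a, r, s', d) in D; if the episode ended, reset the environment
        if it is time to update, then for each update step:
            sample a batch of transitions from D
            compute the targets y(r, s', d) with the clipped double-Q and entropy terms
            update phi1, phi2 with one gradient step on the MSBE loss
            update theta with one gradient step on the policy objective
            (optionally) update the temperature alpha
            soft-update phi_targ1, phi_targ2 with polyak averaging
    until convergence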

1.4 Full Code

Note: the code comes from reference link [2].

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import collections
import random


class SAC_dis:
    def __init__(self,
                 env,
                 n_states,
                 n_hiddens,
                 n_actions,
                 lr_a,
                 lr_c,
                 lr_alpha,
                 target_entropy,
                 tau,
                 gamma,
                 device='cpu'):

        self.env = env
        self.n_states = n_states
        self.n_actions = n_actions
        self.n_hiddens = n_hiddens
        self.lr_a = lr_a
        self.lr_c = lr_c
        self.lr_alpha = lr_alpha
        self.target_entropy = target_entropy
        self.tau = tau
        self.gamma = gamma
        self.device = device

        # policy net
        self.actor = PolicyNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
        # Q1, Q2, target Q1, target Q2
        self.critic_1 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
        self.critic_2 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
        self.target_critic_1 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
        self.target_critic_2 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)

        # initialize target Q net
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())

        # initialize optimizer
        self.optimizer_a = torch.optim.Adam(self.actor.parameters(), lr=lr_a)
        self.optimizer_c1 = torch.optim.Adam(self.critic_1.parameters(), lr=lr_c)
        self.optimizer_c2 = torch.optim.Adam(self.critic_2.parameters(), lr=lr_c)

        # initialize alpha
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=lr_alpha)

    def select_action(self, state):
        # numpy[n_states] --> tensor [1, n_states]
        state = torch.tensor(state[np.newaxis, :], dtype=torch.float).to(self.device)
        # calculate action probabilities
        probs = self.actor(state)
        # sample action from Categorical distribution
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample().item()
        return action

    def calculate_target(self, r, s_, dones):
        """
        calculate the TD target y(r, s', d) used to update both Q-functions
        :param r: rewards
        :param s_: next_states
        :param dones: episode-termination flags
        :return: td_target = r + gamma * (1 - d) * (E_a'[min Q_targ(s', a')] + alpha * H(pi(.|s')))
        """
        next_probs = self.actor(s_)                      # pi(.|s'), shape [batch, n_actions]
        next_log_probs = torch.log(next_probs + 1e-8)
        # entropy of the next-state policy: H(pi(.|s')) = -sum_a pi(a|s') * log pi(a|s')
        entropy = -torch.sum(next_probs * next_log_probs, dim=1, keepdim=True)
        # clipped double-Q: expectation over actions of the element-wise smaller target Q-value
        tar_q1_value = self.target_critic_1(s_)
        tar_q2_value = self.target_critic_2(s_)
        min_tar_q = torch.sum(next_probs * torch.min(tar_q1_value, tar_q2_value), dim=1, keepdim=True)

        # entropy-regularized target: the entropy bonus is added, weighted by alpha
        td_target = r + self.gamma * (1 - dones) * (min_tar_q + self.log_alpha.exp() * entropy)
        return td_target

    def soft_update(self, net, target_net):
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            param_target.data.copy_(param_target.data*(1-self.tau) + param.data*self.tau)

    def update(self, transition_dict):
        # extract (s, a, r, s_, dones)
        states = torch.tensor(transition_dict['s'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['a']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['r'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['s_'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['done'], dtype=torch.float).view(-1, 1).to(self.device)

        td_target = self.calculate_target(rewards, next_states, dones)

        # --------------------------------- #
        # update 2 critic networks, compute loss
        # --------------------------------- #
        critic_1_qvalue = self.critic_1(states).gather(1, actions)
        critic_1_loss = torch.mean(F.mse_loss(critic_1_qvalue, td_target.detach()))
        critic_2_qvalue = self.critic_2(states).gather(1, actions)
        critic_2_loss = torch.mean(F.mse_loss(critic_2_qvalue, td_target.detach()))

        # optimize c1
        self.optimizer_c1.zero_grad()
        critic_1_loss.backward()
        self.optimizer_c1.step()
        # optimize c2
        self.optimizer_c2.zero_grad()
        critic_2_loss.backward()
        self.optimizer_c2.step()

        # --------------------------------- #
        # update policy network, compute loss
        # --------------------------------- #
        probs = self.actor(states)
        log_probs = torch.log(probs + 1e-8)
        entropy = -torch.sum(probs * log_probs, dim=1, keepdim=True)

        # q-functions predict the current q-value
        q1_value = self.critic_1(states)
        q2_value = self.critic_2(states)
        min_qvalue = torch.sum(probs * torch.min(q1_value, q2_value), dim=1, keepdim=True)

        # loss
        actor_loss = torch.mean(-self.log_alpha.exp()*entropy - min_qvalue)
        # optimize actor
        self.optimizer_a.zero_grad()
        actor_loss.backward()
        self.optimizer_a.step()

        # --------------------------------- #
        # update alpha
        # --------------------------------- #

        alpha_loss = torch.mean((entropy - self.target_entropy).detach() * self.log_alpha.exp())
        # policy gradient
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        # soft update target Q1 & Q2
        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)


class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, s, a, r, s_, done):
        self.buffer.append((s, a, r, s_, done))

    def sample(self, batch_size):
        # select randomly
        transitions = random.sample(self.buffer, batch_size)
        s, a, r, s_, done = zip(*transitions)
        return np.array(s), a, r, np.array(s_), done

    def size(self):
        return len(self.buffer)


# policy net: outputs a probability distribution over the discrete actions
class PolicyNet(nn.Module):
    def __init__(self, n_states, n_actions, n_hiddens):
        super(PolicyNet, self).__init__()
        layers = []
        layer_shape = [n_states] + list(n_hiddens) + [n_actions]
        activation = nn.ReLU
        for i in range(len(layer_shape) - 1):
            layers += [nn.Linear(layer_shape[i], layer_shape[i+1])]
            # ReLU on hidden layers only; the output layer stays linear
            if i < len(layer_shape) - 2:
                layers += [activation()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = self.net(x)
        # softmax turns the logits into action probabilities
        return F.softmax(x, dim=-1)


# critic net: outputs one Q-value per discrete action
class CriticNet(nn.Module):
    def __init__(self, n_states, n_actions, n_hiddens):
        super(CriticNet, self).__init__()

        layers = []
        layer_shape = [n_states] + list(n_hiddens) + [n_actions]
        activation = nn.ReLU
        for i in range(len(layer_shape) - 1):
            layers += [nn.Linear(layer_shape[i], layer_shape[i+1])]
            # no activation on the output layer: Q-values are unbounded
            if i < len(layer_shape) - 2:
                layers += [activation()]

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = self.net(x)
        return x
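

A minimal usage sketch (hypothetical, not part of the reference code): it assumes the classic gym API, where env.reset() returns the observation and env.step() returns four values, a discrete-action environment such as CartPole-v1, and reasonable but untuned hyperparameters.

import gym

env = gym.make('CartPole-v1')
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n

agent = SAC_dis(env=env, n_states=n_states, n_hiddens=[128, 128], n_actions=n_actions,
                lr_a=3e-4, lr_c=3e-3, lr_alpha=3e-4,
                target_entropy=-1, tau=0.005, gamma=0.99)
buffer = ReplayBuffer(capacity=10000)
batch_size = 64

for episode in range(200):
    s = env.reset()
    done = False
    episode_return = 0.0
    while not done:
        a = agent.select_action(s)
        s_, r, done, _ = env.step(a)
        buffer.add(s, a, r, s_, done)
        s = s_
        episode_return += r
        # start updating once the buffer holds at least one batch
        if buffer.size() > batch_size:
            s_b, a_b, r_b, s2_b, d_b = buffer.sample(batch_size)
            agent.update({'s': s_b, 'a': a_b, 'r': r_b, 's_': s2_b, 'done': d_b})
    print(episode, episode_return)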



References

[1] Soft Actor-Critic — Spinning Up documentation (openai.com)

[2] [Deep Reinforcement Learning] (7) SAC model walkthrough, with full PyTorch code - CSDN Blog

 
