1. SAC
Soft Actor-Critic (SAC) is an off-policy algorithm that combines stochastic policy optimization with DDPG-style approaches. It can be applied to both continuous and discrete action spaces.
This article's understanding of SAC is based on OpenAI's Soft Actor-Critic — Spinning Up documentation (openai.com).
1.1 Basic Principles
The key feature of SAC is entropy regularization: the policy is trained to trade off expected return against entropy, which is closely related to the exploration-exploitation trade-off. Increasing entropy encourages more exploration and prevents the policy from prematurely converging to a poor local optimum.
1.1.1 Entropy-Regularized RL
To explain Soft Actor-Critic, we first need to introduce the concept of entropy-regularized RL.
Entropy measures how random a variable is. For example, if a coin is weighted so that it almost always lands heads up, it is not very random and its entropy is low; if it is fairly weighted so that heads and tails each come up about half the time, its entropy is high.
Let $x$ be a random variable with probability mass or density function $P$. The entropy $H$ of $x$ is computed from its distribution:
$$H(P) = \mathop{\mathbb{E}}_{x \sim P}\big[ -\log P(x) \big].$$
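As a quick check of the coin example, the entropy of a fair coin versus a heavily biased one can be computed directly from this definition (a small illustrative snippet, not part of the referenced code):

import torch

def shannon_entropy(p):
    # H(P) = E_{x~P}[-log P(x)] for a discrete distribution given as a list of probabilities
    p = torch.tensor(p)
    return -(p * torch.log(p)).sum().item()

print(shannon_entropy([0.5, 0.5]))    # fair coin: about 0.693 nats (= log 2), high entropy
print(shannon_entropy([0.99, 0.01]))  # heavily biased coin: about 0.056 nats, low entropy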
In entropy-regularized RL, the agent receives a bonus reward at each timestep proportional to the entropy of the policy at that timestep.
We can then define value functions that include this entropy bonus, with a trade-off coefficient $\alpha > 0$. The value function $V^{\pi}$ includes the entropy bonus from every timestep:
$$V^{\pi}(s) = \mathop{\mathbb{E}}_{\tau \sim \pi}\left[ \sum_{t=0}^{\infty} \gamma^t \Big( R(s_t, a_t, s_{t+1}) + \alpha H\big(\pi(\cdot \mid s_t)\big) \Big) \,\Big|\, s_0 = s \right].$$
$Q^{\pi}$ also changes: it includes the entropy bonus from every timestep except the first,
$$Q^{\pi}(s,a) = \mathop{\mathbb{E}}_{\tau \sim \pi}\left[ \sum_{t=0}^{\infty} \gamma^t R(s_t, a_t, s_{t+1}) + \alpha \sum_{t=1}^{\infty} \gamma^t H\big(\pi(\cdot \mid s_t)\big) \,\Big|\, s_0 = s, a_0 = a \right].$$
With these definitions, $V^{\pi}$ and $Q^{\pi}$ are connected by
$$V^{\pi}(s) = \mathop{\mathbb{E}}_{a \sim \pi}\big[ Q^{\pi}(s,a) \big] + \alpha H\big(\pi(\cdot \mid s)\big),$$
and the Bellman equation for $Q^{\pi}$ is
$$Q^{\pi}(s,a) = \mathop{\mathbb{E}}_{s' \sim P,\, a' \sim \pi}\Big[ R(s,a,s') + \gamma \big( Q^{\pi}(s',a') + \alpha H\big(\pi(\cdot \mid s')\big) \big) \Big] = \mathop{\mathbb{E}}_{s' \sim P}\big[ R(s,a,s') + \gamma V^{\pi}(s') \big].$$
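In practice the right-hand side of this Bellman equation is approximated with samples: the reward and next state $s'$ come from the replay buffer, while the next action is drawn fresh from the current policy,
$$Q^{\pi}(s,a) \approx r + \gamma \Big( Q^{\pi}(s', \tilde{a}') + \alpha H\big(\pi(\cdot \mid s')\big) \Big), \qquad \tilde{a}' \sim \pi(\cdot \mid s').$$
This is what the target computation in Section 1.2.2 implements; for discrete actions the expectation over $a'$ and the entropy can even be computed exactly rather than sampled.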
1.1.2 Learning Q
SAC uses five networks in total: one policy network and four Q-networks (two online Q-functions plus their two target copies). The Q-function setup is very similar to TD3's, but with a few differences.
The basic idea of learning the Q-networks is the same as in earlier algorithms: of the four Q-networks, two are target networks, and the goal of learning Q is to make the Q-value predicted for each state-action pair as close as possible to the target Q-value, by minimizing the mean-squared Bellman error (MSBE) loss.
The target Q-networks are updated with soft (polyak) updates.
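Concretely, after each gradient step the target parameters take a small step toward the online parameters, matching the `soft_update` method in the code below, where $\tau$ is the polyak coefficient:
$$\phi_{\text{targ}} \leftarrow (1 - \tau)\,\phi_{\text{targ}} + \tau\,\phi.$$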
Similarities:
- Like in TD3, both Q-functions are learned with MSBE minimization, by regressing to a single shared target.
- Like in TD3, the shared target is computed using target Q-networks, and the target Q-networks are obtained by polyak averaging the Q-network parameters over the course of training.
- Like in TD3, the shared target makes use of the clipped double-Q trick.
Differences:
- Unlike in TD3, the target also includes a term that comes from SAC’s use of entropy regularization.
- Unlike in TD3, the next-state actions used in the target come from the current policy instead of a target policy.
- Unlike in TD3, there is no explicit target policy smoothing. TD3 trains a deterministic policy, and so it accomplishes smoothing by adding random noise to the next-state actions. SAC trains a stochastic policy, and so the noise from that stochasticity is sufficient to get a similar effect.
In short, the similarities are: both Q-functions are regressed toward a single shared target, and that target is produced with the clipped double-Q trick (the smaller of the two target Q-networks' outputs is used to compute the target Q-value).
The differences are: the target includes an entropy term; there is only one policy network (no target policy); and because SAC trains a stochastic rather than a deterministic policy, no extra random noise needs to be added to its actions.
When computing the Q-target, since there is only one policy network (unlike TD3 with its target policy), the next-state action used in the target is sampled from the current policy:
$$\tilde{a}' \sim \pi_{\theta}(\cdot \mid s').$$
The Q-functions are updated with the MSBE loss, computed much as in TD3: the clipped double-Q trick takes the smaller of the two target Q-functions' outputs, combined here with the entropy term,
$$y(r, s', d) = r + \gamma (1 - d) \Big( \min_{j=1,2} Q_{\phi_{\text{targ}, j}}(s', \tilde{a}') - \alpha \log \pi_{\theta}(\tilde{a}' \mid s') \Big), \qquad \tilde{a}' \sim \pi_{\theta}(\cdot \mid s').$$
The resulting loss function for each Q-network is
$$L(\phi_i, \mathcal{D}) = \mathop{\mathbb{E}}_{(s,a,r,s',d) \sim \mathcal{D}}\Big[ \big( Q_{\phi_i}(s,a) - y(r,s',d) \big)^2 \Big], \qquad i = 1, 2.$$
1.1.3 Learning the Policy
The policy should act, in each state, so as to maximize the expected future return plus the expected future entropy; it should therefore maximize $V^{\pi}(s)$, which can be expanded as
$$V^{\pi}(s) = \mathop{\mathbb{E}}_{a \sim \pi}\big[ Q^{\pi}(s,a) \big] + \alpha H\big(\pi(\cdot \mid s)\big) = \mathop{\mathbb{E}}_{a \sim \pi}\big[ Q^{\pi}(s,a) - \alpha \log \pi(a \mid s) \big].$$
Selecting an action: in the continuous-action setting described by Spinning Up, actions are drawn with the reparameterization trick,
$$\tilde{a}_{\theta}(s, \xi) = \tanh\big( \mu_{\theta}(s) + \sigma_{\theta}(s) \odot \xi \big), \qquad \xi \sim \mathcal{N}(0, I).$$
The Q-value of the sampled action is then estimated by both Q-functions, and the smaller of the two is used in the policy loss, so the policy parameters are optimized by
$$\max_{\theta} \; \mathop{\mathbb{E}}_{s \sim \mathcal{D},\, \xi \sim \mathcal{N}}\Big[ \min_{j=1,2} Q_{\phi_j}\big(s, \tilde{a}_{\theta}(s, \xi)\big) - \alpha \log \pi_{\theta}\big(\tilde{a}_{\theta}(s, \xi) \,\big|\, s\big) \Big].$$
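The code in this post implements the discrete-action variant, where the expectation over actions and the policy entropy can be computed exactly from the softmax output rather than estimated from samples; this is what the `torch.sum(probs * ...)` expressions in the code below compute:
$$\mathop{\mathbb{E}}_{a \sim \pi}\big[ Q(s,a) \big] = \sum_{a} \pi(a \mid s)\, Q(s,a), \qquad H\big(\pi(\cdot \mid s)\big) = -\sum_{a} \pi(a \mid s) \log \pi(a \mid s).$$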
1.2 Code Walkthrough
Note: the code is taken from reference [2].
1.2.1 Building the Networks
from torch import nn
from torch.nn import functional as F

# policy net: maps a state to a probability distribution over discrete actions
class PolicyNet(nn.Module):
    def __init__(self, n_states, n_actions, n_hiddens):
        super(PolicyNet, self).__init__()
        layers = []
        # n_hiddens is an iterable of hidden-layer widths, e.g. (128, 128)
        layer_shape = [n_states] + list(n_hiddens) + [n_actions]
        activation = nn.ReLU
        for i in range(len(layer_shape) - 1):
            layers.append(nn.Linear(layer_shape[i], layer_shape[i + 1]))
            # ReLU on hidden layers only; the output layer stays linear
            if i < len(layer_shape) - 2:
                layers.append(activation())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = self.net(x)
        # convert logits to action probabilities
        return F.softmax(x, dim=1)
# critic net: maps a state to one Q-value per discrete action
class CriticNet(nn.Module):
    def __init__(self, n_states, n_actions, n_hiddens):
        super(CriticNet, self).__init__()
        layers = []
        layer_shape = [n_states] + list(n_hiddens) + [n_actions]
        activation = nn.ReLU
        for i in range(len(layer_shape) - 1):
            layers.append(nn.Linear(layer_shape[i], layer_shape[i + 1]))
            # ReLU on hidden layers only, so Q-values are not clipped to be non-negative
            if i < len(layer_shape) - 2:
                layers.append(activation())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # returns a [batch, n_actions] tensor of Q-values
        return self.net(x)
# policy net
self.actor = PolicyNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
# Q1, Q2, target Q1, target Q2
self.critic_1 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
self.critic_2 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
self.target_critic_1 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
self.target_critic_2 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
# initialize target Q net
self.target_critic_1.load_state_dict(self.critic_1.state_dict())
self.target_critic_2.load_state_dict(self.critic_2.state_dict())
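A minimal shape check for these networks, assuming `n_hiddens` is an iterable of hidden-layer widths (the sizes below are illustrative, not taken from the referenced code):

import torch

n_states, n_actions, n_hiddens = 4, 2, (128, 128)
actor = PolicyNet(n_states, n_actions, n_hiddens)
critic = CriticNet(n_states, n_actions, n_hiddens)

s = torch.zeros(1, n_states)   # a dummy batch containing one state
print(actor(s))                # [1, n_actions] action probabilities that sum to 1
print(critic(s))               # [1, n_actions] Q-values, one per discrete action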
1.2.2 Updating the Q-Functions
The TD target is computed from the smaller of the two target Q-networks' outputs, plus the entropy bonus of the current policy at the next state:
def calculate_target(self, r, s_, dones):
    """
    Calculate the TD target y = r + gamma * (1 - done) * (E_a'[min_j Q_targ_j(s', a')] + alpha * H(pi(.|s')))
    :param r: rewards, shape [batch, 1]
    :param s_: next states, shape [batch, n_states]
    :param dones: done flags, shape [batch, 1]
    :return: td_target, shape [batch, 1]
    """
    next_probs = self.actor(s_)                      # pi(.|s'), shape [batch, n_actions]
    next_log_probs = torch.log(next_probs + 1e-8)
    # policy entropy at the next state: H = -sum_a pi(a|s') * log pi(a|s')
    entropy = -torch.sum(next_probs * next_log_probs, dim=1, keepdim=True)
    # clipped double-Q: elementwise minimum of the two target critics
    tar_q1_value = self.target_critic_1(s_)
    tar_q2_value = self.target_critic_2(s_)
    # exact expectation over the discrete action distribution
    min_tar_q = torch.sum(next_probs * torch.min(tar_q1_value, tar_q2_value), dim=1, keepdim=True)
    td_target = r + self.gamma * (1 - dones) * (min_tar_q + self.log_alpha.exp() * entropy)
    return td_target
Once the TD target is available, Q1 and Q2 each compute their own TD error and MSE loss against it, and both Q-functions are updated separately.
# --------------------------------- #
# update 2 critic networks, compute loss
# --------------------------------- #
critic_1_qvalue = self.critic_1(states).gather(1, actions)
critic_1_loss = torch.mean(F.mse_loss(critic_1_qvalue, td_target.detach()))
critic_2_qvalue = self.critic_2(states).gather(1, actions)
critic_2_loss = torch.mean(F.mse_loss(critic_2_qvalue, td_target.detach()))
# optimize c1
self.optimizer_c1.zero_grad()
critic_1_loss.backward()
self.optimizer_c1.step()
# optimize c2
self.optimizer_c2.zero_grad()
critic_2_loss.backward()
self.optimizer_c2.step()
The target Q-functions are then updated with a soft update:
# soft update target Q1 & Q2
self.soft_update(self.critic_1, self.target_critic_1)
self.soft_update(self.critic_2, self.target_critic_2)
1.2.3 Updating the Policy
# --------------------------------- #
# update policy network, compute loss
# --------------------------------- #
probs = self.actor(states)
log_probs = torch.log(probs + 1e-8)
entropy = -torch.sum(probs * log_probs, dim=1, keepdim=True)
# q-functions predict the current q-value
q1_value = self.critic_1(states)
q2_value = self.critic_2(states)
min_qvalue = torch.sum(probs * torch.min(q1_value, q2_value), dim=1, keepdim=True)
# loss
actor_loss = torch.mean(-self.log_alpha.exp()*entropy - min_qvalue)
# optimize actor
self.optimizer_a.zero_grad()
actor_loss.backward()
self.optimizer_a.step()
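The walkthrough above covers the critic and policy updates; the full code in Section 1.4 additionally tunes the temperature $\alpha$ automatically by pushing the policy entropy toward `target_entropy`. The corresponding update, reproduced from Section 1.4 with an added comment, is:

# --------------------------------- #
# update alpha (temperature)
# --------------------------------- #
# if the policy entropy falls below target_entropy, alpha increases, so exploration is rewarded more
alpha_loss = torch.mean((entropy - self.target_entropy).detach() * self.log_alpha.exp())
self.log_alpha_optimizer.zero_grad()
alpha_loss.backward()
self.log_alpha_optimizer.step()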
1.3 Pseudocode
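The pseudocode figure from the Spinning Up documentation is not reproduced here. As a substitute, the sketch below shows one plausible training loop for the `SAC_dis` agent and `ReplayBuffer` defined in Section 1.4; the environment, network sizes, buffer capacity, batch size, and other hyperparameters are illustrative assumptions rather than values from the referenced code, and the classic gym API (reset returns an observation, step returns four values) is assumed.

import gym

# illustrative setup: CartPole-v1 has a discrete action space, matching the discrete SAC code
env = gym.make('CartPole-v1')
agent = SAC_dis(env=env,
                n_states=env.observation_space.shape[0],
                n_hiddens=(128, 128),
                n_actions=env.action_space.n,
                lr_a=3e-4, lr_c=3e-4, lr_alpha=3e-4,
                target_entropy=-1.0, tau=0.005, gamma=0.99)
buffer = ReplayBuffer(capacity=10000)
batch_size = 64

for episode in range(200):
    s = env.reset()
    done = False
    while not done:
        a = agent.select_action(s)
        s_, r, done, _ = env.step(a)
        buffer.add(s, a, r, s_, done)
        s = s_
        # start updating once the buffer holds at least one batch of transitions
        if buffer.size() > batch_size:
            bs, ba, br, bs_, bdone = buffer.sample(batch_size)
            agent.update({'s': bs, 'a': ba, 'r': br, 's_': bs_, 'done': bdone})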
1.4 Complete Code
Note: the code is taken from reference [2].
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import collections
import random
class SAC_dis:
    def __init__(self,
                 env,
                 n_states,
                 n_hiddens,
                 n_actions,
                 lr_a,
                 lr_c,
                 lr_alpha,
                 target_entropy,
                 tau,
                 gamma,
                 device='cpu'):
        self.env = env
        self.n_states = n_states
        self.n_actions = n_actions
        self.n_hiddens = n_hiddens
        self.lr_a = lr_a
        self.lr_c = lr_c
        self.lr_alpha = lr_alpha
        self.target_entropy = target_entropy
        self.tau = tau
        self.gamma = gamma
        self.device = device
        # policy net
        self.actor = PolicyNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
        # Q1, Q2, target Q1, target Q2
        self.critic_1 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
        self.critic_2 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
        self.target_critic_1 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
        self.target_critic_2 = CriticNet(self.n_states, self.n_actions, self.n_hiddens).to(device)
        # initialize target Q net
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        # initialize optimizer
        self.optimizer_a = torch.optim.Adam(self.actor.parameters(), lr=lr_a)
        self.optimizer_c1 = torch.optim.Adam(self.critic_1.parameters(), lr=lr_c)
        self.optimizer_c2 = torch.optim.Adam(self.critic_2.parameters(), lr=lr_c)
        # initialize alpha
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=lr_alpha)

    def select_action(self, state):
        # numpy[n_states] --> tensor [1, n_states]
        state = torch.tensor(state[np.newaxis, :], dtype=torch.float).to(self.device)
        # calculate action probabilities
        probs = self.actor(state)
        # sample action from Categorical distribution
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample().item()
        return action
    def calculate_target(self, r, s_, dones):
        """
        Calculate the TD target y = r + gamma * (1 - done) * (E_a'[min_j Q_targ_j(s', a')] + alpha * H(pi(.|s')))
        :param r: rewards, shape [batch, 1]
        :param s_: next states, shape [batch, n_states]
        :param dones: done flags, shape [batch, 1]
        :return: td_target, shape [batch, 1]
        """
        next_probs = self.actor(s_)                      # pi(.|s'), shape [batch, n_actions]
        next_log_probs = torch.log(next_probs + 1e-8)
        # policy entropy at the next state: H = -sum_a pi(a|s') * log pi(a|s')
        entropy = -torch.sum(next_probs * next_log_probs, dim=1, keepdim=True)
        # clipped double-Q: elementwise minimum of the two target critics
        tar_q1_value = self.target_critic_1(s_)
        tar_q2_value = self.target_critic_2(s_)
        # exact expectation over the discrete action distribution
        min_tar_q = torch.sum(next_probs * torch.min(tar_q1_value, tar_q2_value), dim=1, keepdim=True)
        td_target = r + self.gamma * (1 - dones) * (min_tar_q + self.log_alpha.exp() * entropy)
        return td_target

    def soft_update(self, net, target_net):
        # polyak averaging: target <- (1 - tau) * target + tau * online
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            param_target.data.copy_(param_target.data * (1 - self.tau) + param.data * self.tau)
    def update(self, transition_dict):
        # extract (s, a, r, s_, dones)
        states = torch.tensor(transition_dict['s'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['a']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['r'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['s_'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['done'], dtype=torch.float).view(-1, 1).to(self.device)
        td_target = self.calculate_target(rewards, next_states, dones)
        # --------------------------------- #
        # update 2 critic networks, compute loss
        # --------------------------------- #
        critic_1_qvalue = self.critic_1(states).gather(1, actions)
        critic_1_loss = torch.mean(F.mse_loss(critic_1_qvalue, td_target.detach()))
        critic_2_qvalue = self.critic_2(states).gather(1, actions)
        critic_2_loss = torch.mean(F.mse_loss(critic_2_qvalue, td_target.detach()))
        # optimize c1
        self.optimizer_c1.zero_grad()
        critic_1_loss.backward()
        self.optimizer_c1.step()
        # optimize c2
        self.optimizer_c2.zero_grad()
        critic_2_loss.backward()
        self.optimizer_c2.step()
        # --------------------------------- #
        # update policy network, compute loss
        # --------------------------------- #
        probs = self.actor(states)
        log_probs = torch.log(probs + 1e-8)
        entropy = -torch.sum(probs * log_probs, dim=1, keepdim=True)
        # q-functions predict the current q-value
        q1_value = self.critic_1(states)
        q2_value = self.critic_2(states)
        min_qvalue = torch.sum(probs * torch.min(q1_value, q2_value), dim=1, keepdim=True)
        # loss
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy - min_qvalue)
        # optimize actor
        self.optimizer_a.zero_grad()
        actor_loss.backward()
        self.optimizer_a.step()
        # --------------------------------- #
        # update alpha
        # --------------------------------- #
        alpha_loss = torch.mean((entropy - self.target_entropy).detach() * self.log_alpha.exp())
        # policy gradient
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()
        # soft update target Q1 & Q2
        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, s, a, r, s_, done):
        self.buffer.append((s, a, r, s_, done))

    def sample(self, batch_size):
        # select randomly
        transitions = random.sample(self.buffer, batch_size)
        s, a, r, s_, done = zip(*transitions)
        return np.array(s), a, r, np.array(s_), done

    def size(self):
        return len(self.buffer)
# policy net: maps a state to a probability distribution over discrete actions
class PolicyNet(nn.Module):
    def __init__(self, n_states, n_actions, n_hiddens):
        super(PolicyNet, self).__init__()
        layers = []
        # n_hiddens is an iterable of hidden-layer widths, e.g. (128, 128)
        layer_shape = [n_states] + list(n_hiddens) + [n_actions]
        activation = nn.ReLU
        for i in range(len(layer_shape) - 1):
            layers.append(nn.Linear(layer_shape[i], layer_shape[i + 1]))
            # ReLU on hidden layers only; the output layer stays linear
            if i < len(layer_shape) - 2:
                layers.append(activation())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = self.net(x)
        # convert logits to action probabilities
        return F.softmax(x, dim=1)
# critic net: maps a state to one Q-value per discrete action
class CriticNet(nn.Module):
    def __init__(self, n_states, n_actions, n_hiddens):
        super(CriticNet, self).__init__()
        layers = []
        layer_shape = [n_states] + list(n_hiddens) + [n_actions]
        activation = nn.ReLU
        for i in range(len(layer_shape) - 1):
            layers.append(nn.Linear(layer_shape[i], layer_shape[i + 1]))
            # ReLU on hidden layers only, so Q-values are not clipped to be non-negative
            if i < len(layer_shape) - 2:
                layers.append(activation())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # returns a [batch, n_actions] tensor of Q-values
        return self.net(x)
References
[1] Soft Actor-Critic — Spinning Up documentation (openai.com)
[2] 【深度强化学习】(7) SAC 模型解析，附Pytorch完整代码 (CSDN博客)