Classic Reinforcement Learning Algorithms, Note 14: A PyTorch Implementation of TD3 (Twin Delayed Deep Deterministic Policy Gradient)
Overview of the TD3 Algorithm
TD3 is short for Twin Delayed Deep Deterministic policy gradient algorithm. As the name suggests, TD3 is an improved version of DDPG. The algorithm comes from the paper
Addressing Function Approximation Error in Actor-Critic Methods.
Compared with DDPG, TD3 introduces the following key improvements:
- Double (twin) Q networks
- Improved Critic learning
- Improved Actor learning
- Target policy smoothing regularization
For more detailed introductions, see:
https://zhuanlan.zhihu.com/p/111334500
https://zhuanlan.zhihu.com/p/88446488
Detailed Explanation
TD3 tackles two problems: overestimation and high variance.
Overestimation of values is a problem that value-based methods frequently run into. While the Q-network is still immature, its estimates carry errors, so taking the max over Q-values tends to exceed the true maximum Q-value. The accumulated overestimation can trap the algorithm in a suboptimal policy and even lead to divergence.
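A toy illustration (not TD3 code) of why the max operator causes this: even when each per-action estimate is unbiased, the maximum over noisy estimates is biased upward.

```python
import numpy as np

rng = np.random.default_rng(0)
true_q = np.array([1.0, 1.0, 1.0])   # three actions with identical true Q-values
noise_std = 0.5                       # estimation error of an immature Q-network

# 10000 independent noisy (but unbiased) estimates of the three Q-values
estimates = true_q + rng.normal(0.0, noise_std, size=(10000, 3))
print("true max Q:           ", true_q.max())                  # 1.0
print("average of max(Q_hat):", estimates.max(axis=1).mean())  # ~1.42, clearly above 1.0
```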
The TD3 paper shows that this overestimation phenomenon also appears in actor-critic algorithms.
The paper mentions two remedies for overestimation. One is the Double DQN approach: use the target Q-network and the main Q-network separately for value estimation and action selection, decoupling the two. The other is Double Q-learning: maintain two independent critics that estimate the value separately, and take the minimum of the two estimates. The higher variance this method brings can be alleviated by clipping during the update, which yields the target:
$$y = r_t + \gamma \min_{i=1,2} Q_{\theta_i'}\big(s', \pi_{\phi_1}(s')\big)$$
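In code, this clipped double-Q target amounts to evaluating both target critics at the target policy's action and taking the element-wise minimum. Here is a minimal standalone sketch; the function and argument names are illustrative, and the actual computation lives in TD3.update() further below (where the target action additionally receives the smoothing noise described in the next subsection).

```python
import torch

def clipped_double_q_target(critic_1_target, critic_2_target, actor_target,
                            next_state, reward, done, gamma=0.99):
    """Compute y = r + gamma * min_i Q'_i(s', pi'(s')) without tracking gradients."""
    with torch.no_grad():
        next_action = actor_target(next_state)
        q1 = critic_1_target(next_state, next_action)
        q2 = critic_2_target(next_state, next_action)
        return reward + gamma * (1 - done) * torch.min(q1, q2)
```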
To combat high variance, the paper uses three techniques.
The first is the target network, in use ever since DQN, which reduces variance by lowering the update frequency of the critic's target.
The second is lowering the update frequency of the Actor, called delayed policy updates; this decouples the value function from the policy.
The third is target policy smoothing regularization, a regularization technique that counters the tendency of a deterministic policy to overfit to narrow peaks in the value estimate. Intuitively, when computing the target for the Q-value update, the target action is constructed as follows (see the code sketch after the formula):
$$\tilde{a} \leftarrow \pi_{\phi'}(s') + \epsilon, \quad \epsilon \sim \operatorname{clip}\big(\mathcal{N}(0, \tilde{\sigma}), -c, c\big)$$
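A minimal sketch of target policy smoothing, assuming actor_target is the target policy network and max_action bounds the action range; the function name is illustrative, and the full version appears in TD3.update() below.

```python
import torch

def smoothed_target_action(actor_target, next_state, max_action,
                           policy_noise=0.2, noise_clip=0.5):
    """a~ = clip(pi'(s') + clip(N(0, sigma~), -c, c), -max_action, max_action)."""
    with torch.no_grad():
        action = actor_target(next_state)
        noise = (torch.randn_like(action) * policy_noise).clamp(-noise_clip, noise_clip)
        return (action + noise).clamp(-max_action, max_action)
```

Delayed policy updates then simply mean that the actor and the target networks are updated only once every policy_delay critic updates.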
Algorithm Flowchart
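In outline, the flowchart (Algorithm 1 in the TD3 paper) runs as follows: at every environment step, act with the current policy plus exploration noise and store the transition in a replay buffer; sample a mini-batch, build the smoothed target action and the clipped double-Q target, and update both critics; then, only every policy_delay steps, update the actor with the deterministic policy gradient through critic 1 and soft-update the actor and critic target networks with coefficient tau. The implementation below follows this structure.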
Algorithm Implementation
import argparse
from collections import namedtuple
from itertools import count
import os, sys, random
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
from tensorboardX import SummaryWriter
device = 'cuda' if torch.cuda.is_available() else 'cpu'
parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='train', type=str) # mode = 'train' or 'test'
parser.add_argument("--env_name", default="LunarLanderContinuous-v2") # OpenAI gym environment name, BipedalWalker-v2 Pendulum-v0
parser.add_argument('--tau', default=0.05, type=float) # target smoothing coefficient
parser.add_argument('--target_update_interval', default=1, type=int)
parser.add_argument('--test_episode', default=50, type=int)
parser.add_argument('--epoch', default=10, type=int) # number of mini-batch updates per call to update()
parser.add_argument('--learning_rate', default=3e-4, type=float)
parser.add_argument('--gamma', default=0.99, type=float) # discount factor
parser.add_argument('--capacity', default=50000, type=int) # replay buffer size
parser.add_argument('--num_episode', default=1000, type=int) # num of episodes in training
parser.add_argument('--batch_size', default=100, type=int) # mini batch size
parser.add_argument('--seed', default=True, type=bool)
parser.add_argument('--random_seed', default=9527, type=int)
# optional parameters
# parser.add_argument('--num_hidden_layers', default=2, type=int)
# parser.add_argument('--sample_frequency', default=256, type=int)
# parser.add_argument('--activation', default='Relu', type=str)
parser.add_argument('--render', default=False, type=bool) # show UI or not
parser.add_argument('--log_interval', default=50, type=int) # save the model every 50 episodes
parser.add_argument('--load', default=False, type=bool) # whether to load a saved model before training
parser.add_argument('--render_interval', default=100, type=int) # after render_interval, the env.render() will work
parser.add_argument('--policy_noise', default=0.2, type=float) # std of the smoothing noise added to the target action
parser.add_argument('--noise_clip', default=0.5, type=float) # clipping range of the smoothing noise
parser.add_argument('--policy_delay', default=2, type=int) # actor/target updates happen once every policy_delay critic updates
parser.add_argument('--exploration_noise', default=0.1, type=float)
parser.add_argument('--max_frame', default=200, type=int) # maximum steps per episode
parser.add_argument('--print_log', default=5, type=int)
args = parser.parse_args()
script_name = os.path.basename(__file__)
env = gym.make(args.env_name)
env = env.unwrapped
if args.seed:
env.seed(args.random_seed)
torch.manual_seed(args.random_seed)
np.random.seed(args.random_seed)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0]) # upper bound of the action range
min_Val = torch.tensor(1e-7).float().to(device) # min value
directory = './exp' + script_name + args.env_name + '/'
'''
Implementation of TD3 with pytorch
Original paper: https://arxiv.org/abs/1802.09477
Not the author's implementation !
'''
class Replay_buffer():
'''
Code based on:
https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py
Expects tuples of (state, next_state, action, reward, done)
'''
def __init__(self, max_size=args.capacity):
self.storage = []
self.max_size = max_size
self.ptr = 0
def push(self, data):
if len(self.storage) == self.max_size:
self.storage[int(self.ptr)] = data
self.ptr = (self.ptr + 1) % self.max_size
else:
self.storage.append(data)
def sample(self, batch_size):
ind = np.random.randint(0, len(self.storage), size=batch_size)
x, y, u, r, d = [], [], [], [], []
for i in ind:
X, Y, U, R, D = self.storage[i]
x.append(np.array(X, copy=False))
y.append(np.array(Y, copy=False))
u.append(np.array(U, copy=False))
r.append(np.array(R, copy=False))
d.append(np.array(D, copy=False))
return np.array(x), np.array(y), np.array(u), np.array(r).reshape(-1, 1), np.array(d).reshape(-1, 1)
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, max_action):
super(Actor, self).__init__()
self.fc1 = nn.Linear(state_dim, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, action_dim)
self.max_action = max_action
def forward(self, state):
a = F.relu(self.fc1(state))
a = F.relu(self.fc2(a))
a = torch.tanh(self.fc3(a)) * self.max_action
return a
class Critic(nn.Module):
def __init__(self, state_dim, action_dim):
super(Critic, self).__init__()
self.fc1 = nn.Linear(state_dim + action_dim, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, 1)
def forward(self, state, action):
state_action = torch.cat([state, action], 1)
q = F.relu(self.fc1(state_action))
q = F.relu(self.fc2(q))
q = self.fc3(q)
return q
class TD3():
def __init__(self, state_dim, action_dim, max_action):
self.lr_actor = args.learning_rate
self.lr_critic = args.learning_rate
self.betas = (0.9,0.999)
# six networks: actor, two critics, and a target copy of each
self.actor = Actor(state_dim, action_dim, max_action).to(device)
self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
self.critic_1 = Critic(state_dim, action_dim).to(device)
self.critic_1_target = Critic(state_dim, action_dim).to(device)
self.critic_2 = Critic(state_dim, action_dim).to(device)
self.critic_2_target = Critic(state_dim, action_dim).to(device)
# optimizers
self.actor_optimizer = optim.Adam(self.actor.parameters(),lr=self.lr_actor,betas=self.betas)
self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(),lr=self.lr_critic,betas=self.betas)
self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(),lr=self.lr_critic,betas=self.betas)
self.actor_target.load_state_dict(self.actor.state_dict())
self.critic_1_target.load_state_dict(self.critic_1.state_dict())
self.critic_2_target.load_state_dict(self.critic_2.state_dict())
self.max_action = max_action
self.memory = Replay_buffer(args.capacity)
self.writer = SummaryWriter(directory)
self.num_critic_update_iteration = 0
self.num_actor_update_iteration = 0
self.num_training = 0
def select_action(self, state):
state = torch.tensor(state.reshape(1, -1)).float().to(device)
return self.actor(state).cpu().data.numpy().flatten()
def update(self, epoch):
# if self.num_training % 500 == 0:
# print("====================================")
# print("model has been trained for {} times...".format(self.num_training))
# print("====================================")
for i in range(epoch):
x, y, u, r, d = self.memory.sample(args.batch_size)
state = torch.FloatTensor(x).to(device)
action = torch.FloatTensor(u).to(device)
next_state = torch.FloatTensor(y).to(device)
done = torch.FloatTensor(d).to(device)
reward = torch.FloatTensor(r).to(device)
# Select next action according to target policy:
noise = torch.ones_like(action).data.normal_(0, args.policy_noise).to(device)
noise = noise.clamp(-args.noise_clip, args.noise_clip)
next_action = (self.actor_target(next_state) + noise)
next_action = next_action.clamp(-self.max_action, self.max_action)
# Compute target Q-value:
target_Q1 = self.critic_1_target(next_state, next_action)
target_Q2 = self.critic_2_target(next_state, next_action)
target_Q = torch.min(target_Q1, target_Q2)
target_Q = reward + ((1 - done) * args.gamma * target_Q).detach()
# Optimize Critic 1:
current_Q1 = self.critic_1(state, action)
loss_Q1 = F.mse_loss(current_Q1, target_Q)
self.critic_1_optimizer.zero_grad()
loss_Q1.backward()
self.critic_1_optimizer.step()
self.writer.add_scalar('Loss/Q1_loss', loss_Q1, global_step=self.num_critic_update_iteration)
# Optimize Critic 2:
current_Q2 = self.critic_2(state, action)
loss_Q2 = F.mse_loss(current_Q2, target_Q)
self.critic_2_optimizer.zero_grad()
loss_Q2.backward()
self.critic_2_optimizer.step()
self.writer.add_scalar('Loss/Q2_loss', loss_Q2, global_step=self.num_critic_update_iteration)
# Delayed policy updates:
if i % args.policy_delay == 0:
# Compute actor loss:
actor_loss = -self.critic_1(state, self.actor(state)).mean()  # As training proceeds, Q1 and Q2 become more and more alike, so whether the actor uses Q1, Q2, or both makes little difference.
# Optimize the actor
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
self.writer.add_scalar('Loss/actor_loss', actor_loss, global_step=self.num_actor_update_iteration)
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
target_param.data.copy_(((1- args.tau) * target_param.data) + args.tau * param.data)
for param, target_param in zip(self.critic_1.parameters(), self.critic_1_target.parameters()):
target_param.data.copy_(((1 - args.tau) * target_param.data) + args.tau * param.data)
for param, target_param in zip(self.critic_2.parameters(), self.critic_2_target.parameters()):
target_param.data.copy_(((1 - args.tau) * target_param.data) + args.tau * param.data)
self.num_actor_update_iteration += 1
self.num_critic_update_iteration += 1
self.num_training += 1
def save(self):
torch.save(self.actor.state_dict(), directory+'actor.pth')
torch.save(self.actor_target.state_dict(), directory+'actor_target.pth')
torch.save(self.critic_1.state_dict(), directory+'critic_1.pth')
torch.save(self.critic_1_target.state_dict(), directory+'critic_1_target.pth')
torch.save(self.critic_2.state_dict(), directory+'critic_2.pth')
torch.save(self.critic_2_target.state_dict(), directory+'critic_2_target.pth')
print("====================================")
print("Model has been saved...")
print("====================================")
def load(self):
self.actor.load_state_dict(torch.load(directory + 'actor.pth'))
self.actor_target.load_state_dict(torch.load(directory + 'actor_target.pth'))
self.critic_1.load_state_dict(torch.load(directory + 'critic_1.pth'))
self.critic_1_target.load_state_dict(torch.load(directory + 'critic_1_target.pth'))
self.critic_2.load_state_dict(torch.load(directory + 'critic_2.pth'))
self.critic_2_target.load_state_dict(torch.load(directory + 'critic_2_target.pth'))
print("====================================")
print("model has been loaded...")
print("====================================")
def main():
agent = TD3(state_dim, action_dim, max_action)
ep_r = 0
if args.mode == 'test':
agent.load()
for epoch in range(args.test_episode): # 50
state = env.reset()
for t in count():
action = agent.select_action(state)
next_state, reward, done, info = env.step(np.float32(action))
ep_r += reward
env.render()
if done or t == args.max_frame -1:
print("Ep_i \t{}, the ep_r is \t{:0.2f}, the step is \t{}".format(epoch, ep_r, t))
break
state = next_state
elif args.mode == 'train':
print("====================================")
print("Collection Experience...")
print("====================================")
if args.load: agent.load()
for epoch in range(args.num_episode):
state = env.reset()
for t in range(args.max_frame):
action = agent.select_action(state)
action = action + np.random.normal(0, args.exploration_noise, size=env.action_space.shape[0])
action = action.clip(env.action_space.low, env.action_space.high)
next_state, reward, done, info = env.step(action)
ep_r += reward
if args.render and epoch >= args.render_interval:
env.render()
agent.memory.push((state, next_state, action, reward, float(done)))
state = next_state
if len(agent.memory.storage) >= args.capacity-1:  # updates start only once the replay buffer is full
agent.update(args.epoch)
if done or t == args.max_frame -1:
agent.writer.add_scalar('ep_r', ep_r, global_step=epoch)
if epoch % args.print_log == 0:
print("Ep_i {}, the ep_r is {:0.2f}, the step is {}".format(epoch, ep_r, t))
ep_r = 0
break
if epoch % args.log_interval == 0:
agent.save()
else:
raise NameError("mode wrong!!!")
if __name__ == '__main__':
main()
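Assuming the script is saved as, say, td3.py (the file name is only used to build the save/log directory, so any name works), training and testing would be launched roughly like this:

```bash
# hypothetical invocations; adjust the script name to your own file
python td3.py --mode train --env_name LunarLanderContinuous-v2
python td3.py --mode test
```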