Deep Reinforcement Learning: The TD3 Algorithm

Paper: https://arxiv.org/pdf/1802.09477.pdf

TD3 (Twin Delayed Deep Deterministic Policy Gradient) is an algorithm suited to high-dimensional continuous action spaces. It is an improved version of DDPG, designed to address DDPG's tendency to overestimate Q values during training.

Improvements over DDPG:

1. Twin critic networks. Two networks are used to estimate the action-value function, and during training the smaller of the two Q estimates is taken as the target value (to keep overestimation errors from accumulating).

2. Delayed updates. The critic networks are updated more often than the actor network (similar in spirit to a GAN: the critic has to be trained well before its feedback can usefully guide the actor).

3. Noise clipping. The random noise added to the target action is clipped to a fixed range, so the perturbed target action stays close to the original one.

4. Target policy smoothing. When the critic networks are updated, random noise is added to the target action, which smooths the learned value function and makes the critics more robust to small action perturbations (see the sketch after this list).
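
Improvements 1, 3 and 4 all show up in how the critic targets are computed. The sketch below is illustrative only (the function and argument names such as `td3_target`, `actor_target`, `q1_target` are made up here and not taken from the implementation further down); it builds a clipped double-Q target from a noise-smoothed target action:

```python
import torch

def td3_target(r, s_prime, dead_mask, actor_target, q1_target, q2_target,
               gamma=0.99, policy_noise=0.2, noise_clip=0.5, max_action=1.0):
    """Illustrative TD3 target: clipped double-Q plus target policy smoothing."""
    with torch.no_grad():
        # Target policy smoothing: perturb the target action with clipped noise
        a_prime = actor_target(s_prime)
        noise = (torch.randn_like(a_prime) * policy_noise).clamp(-noise_clip, noise_clip)
        a_prime = (a_prime + noise).clamp(-max_action, max_action)
        # Clipped double-Q: take the smaller of the two target critics' estimates
        q_min = torch.min(q1_target(s_prime, a_prime), q2_target(s_prime, a_prime))
        # Standard TD target; dead_mask zeroes the bootstrap term at terminal states
        return r + gamma * (1.0 - dead_mask) * q_min
```

The remaining improvement, delayed updates, simply means that the actor and the target networks are updated only once every few critic updates, as in the full implementation further down.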

Algorithm flow:

The pseudocode of the algorithm is given as a figure in the original paper.

Code implementation:

Actor:

```python
# Shared imports and device for all the code snippets below
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, net_width, maxaction):
        super(Actor, self).__init__()

        self.l1 = nn.Linear(state_dim, net_width)
        self.l2 = nn.Linear(net_width, net_width)
        self.l3 = nn.Linear(net_width, action_dim)

        self.maxaction = maxaction

    def forward(self, state):
        a = torch.tanh(self.l1(state))
        a = torch.tanh(self.l2(a))
        a = torch.tanh(self.l3(a)) * self.maxaction
        return a
```

Critic:

```python
class Q_Critic(nn.Module):
    def __init__(self, state_dim, action_dim, net_width):
        super(Q_Critic, self).__init__()

        # Q1 architecture
        self.l1 = nn.Linear(state_dim + action_dim, net_width)
        self.l2 = nn.Linear(net_width, net_width)
        self.l3 = nn.Linear(net_width, 1)

        # Q2 architecture
        self.l4 = nn.Linear(state_dim + action_dim, net_width)
        self.l5 = nn.Linear(net_width, net_width)
        self.l6 = nn.Linear(net_width, 1)

    def forward(self, state, action):
        sa = torch.cat([state, action], 1)

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)

        q2 = F.relu(self.l4(sa))
        q2 = F.relu(self.l5(q2))
        q2 = self.l6(q2)
        return q1, q2

    def Q1(self, state, action):
        sa = torch.cat([state, action], 1)

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        return q1
```
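
As a quick sanity check, the two networks above can be instantiated with some arbitrary dimensions (the numbers below are illustrative and not from the original post) to confirm the tensor shapes they produce:

```python
state_dim, action_dim, net_width, max_action = 3, 1, 128, 2.0

actor = Actor(state_dim, action_dim, net_width, max_action).to(device)
critic = Q_Critic(state_dim, action_dim, net_width).to(device)

s = torch.randn(32, state_dim, device=device)
a = actor(s)            # shape (32, 1); every entry lies in [-max_action, max_action]
q1, q2 = critic(s, a)   # two independent Q estimates, each of shape (32, 1)
print(a.shape, q1.shape, q2.shape)
```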

The full TD3 implementation:

```python
class TD3(object):
    def __init__(
            self,
            env_with_Dead,
            state_dim,
            action_dim,
            max_action,
            gamma=0.99,
            net_width=128,
            a_lr=1e-4,
            c_lr=1e-4,
            Q_batchsize=256
    ):

        self.actor = Actor(state_dim, action_dim, net_width, max_action).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=a_lr)
        self.actor_target = copy.deepcopy(self.actor)

        self.q_critic = Q_Critic(state_dim, action_dim, net_width).to(device)
        self.q_critic_optimizer = torch.optim.Adam(self.q_critic.parameters(), lr=c_lr)
        self.q_critic_target = copy.deepcopy(self.q_critic)

        self.env_with_Dead = env_with_Dead
        self.action_dim = action_dim
        self.max_action = max_action
        self.gamma = gamma
        self.policy_noise = 0.2 * max_action
        self.noise_clip = 0.5 * max_action
        self.tau = 0.005
        self.Q_batchsize = Q_batchsize
        self.delay_counter = -1
        self.delay_freq = 1

    def select_action(self, state):  # only used when interact with the env
        with torch.no_grad():
            state = torch.FloatTensor(state.reshape(1, -1)).to(device)
            a = self.actor(state)
        return a.cpu().numpy().flatten()

    def train(self, replay_buffer):
        self.delay_counter += 1
        with torch.no_grad():
            s, a, r, s_prime, dead_mask = replay_buffer.sample(self.Q_batchsize)
            noise = (torch.randn_like(a) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            smoothed_target_a = (
                    self.actor_target(s_prime) + noise  # Noisy on target action
            ).clamp(-self.max_action, self.max_action)

        # Compute the target Q value (no gradient should flow into the target networks)
        with torch.no_grad():
            target_Q1, target_Q2 = self.q_critic_target(s_prime, smoothed_target_a)
            target_Q = torch.min(target_Q1, target_Q2)
            '''DEAD OR NOT'''
            if self.env_with_Dead:
                target_Q = r + (1 - dead_mask) * self.gamma * target_Q  # env with dead
            else:
                target_Q = r + self.gamma * target_Q  # env without dead

        # Get current Q estimates
        current_Q1, current_Q2 = self.q_critic(s, a)

        # Compute critic loss
        q_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        # Optimize the q_critic
        self.q_critic_optimizer.zero_grad()
        q_loss.backward()
        self.q_critic_optimizer.step()

        if self.delay_counter == self.delay_freq:
            # Update Actor
            a_loss = -self.q_critic.Q1(s, self.actor(s)).mean()
            self.actor_optimizer.zero_grad()
            a_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(self.q_critic.parameters(), self.q_critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            self.delay_counter = -1

    def save(self, episode):
        torch.save(self.actor.state_dict(), "td3_actor{}.pth".format(episode))
        torch.save(self.q_critic.state_dict(), "td3_q_critic{}.pth".format(episode))

    def load(self, episode):
        self.actor.load_state_dict(torch.load("td3_actor{}.pth".format(episode)))
        self.q_critic.load_state_dict(torch.load("td3_q_critic{}.pth".format(episode)))
```
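
Below is a minimal sketch of how this class could be exercised. The `RandomBuffer` stub is a made-up helper that just returns random batches in the shapes `train()` expects; the real implementation assumes a proper replay buffer whose `sample(batch_size)` returns `(s, a, r, s_prime, dead_mask)` tensors:

```python
import numpy as np

class RandomBuffer:
    """Stub buffer returning random batches; only for shape-checking TD3.train."""
    def __init__(self, state_dim, action_dim):
        self.state_dim, self.action_dim = state_dim, action_dim

    def sample(self, batch_size):
        s = torch.randn(batch_size, self.state_dim, device=device)
        a = torch.rand(batch_size, self.action_dim, device=device) * 2 - 1
        r = torch.randn(batch_size, 1, device=device)
        s_prime = torch.randn(batch_size, self.state_dim, device=device)
        dead_mask = torch.zeros(batch_size, 1, device=device)
        return s, a, r, s_prime, dead_mask

agent = TD3(env_with_Dead=False, state_dim=3, action_dim=1, max_action=1.0)
buffer = RandomBuffer(state_dim=3, action_dim=1)

for _ in range(4):
    # Critics are updated on every call; the actor and the target networks are
    # updated only when delay_counter reaches delay_freq.
    agent.train(buffer)

print(agent.select_action(np.zeros(3)))
```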

Network structure diagram:

The parameters of the actor and of the target networks are updated with a delay. In other words, critic1 and critic2 are updated continuously, and only once the critics have been trained well can they judge whether the actor is producing good actions and guide it accordingly.

 

The strengths and weaknesses of TD3 compared with DDPG are as follows.

Advantages:

1. TD3 is more stable than DDPG and tends to converge faster.
2. TD3 introduces target policy smoothing regularization, which reduces overfitting to sharp peaks in the value estimate.
3. TD3 uses two critic networks and takes the minimum of their estimates, which gives a more reliable estimate of the Q-value function.

Disadvantages:

1. TD3 is more complex than DDPG and needs more computing resources.
2. Taking the minimum of the two critics can in some cases lead to underestimated Q values.
3. TD3 is fairly sensitive to the choice of hyperparameters and requires more careful tuning.

Below is an example of using TD3 to solve a continuous-control problem:

```python
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define the Actor network
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.layer1 = nn.Linear(state_dim, 400)
        self.layer2 = nn.Linear(400, 300)
        self.layer3 = nn.Linear(300, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = F.relu(self.layer1(state))
        x = F.relu(self.layer2(x))
        x = self.max_action * torch.tanh(self.layer3(x))
        return x

# Define the Critic network
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.layer1 = nn.Linear(state_dim + action_dim, 400)
        self.layer2 = nn.Linear(400, 300)
        self.layer3 = nn.Linear(300, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = self.layer3(x)
        return x

# Define the TD3 algorithm
class TD3(object):
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = Actor(state_dim, action_dim, max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=0.001)

        self.critic1 = Critic(state_dim, action_dim)
        self.critic1_target = Critic(state_dim, action_dim)
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=0.001)

        self.critic2 = Critic(state_dim, action_dim)
        self.critic2_target = Critic(state_dim, action_dim)
        self.critic2_target.load_state_dict(self.critic2.state_dict())
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=0.001)

        self.max_action = max_action

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1))
        return self.actor(state).cpu().data.numpy().flatten()

    def train(self, replay_buffer, iterations, batch_size=100, discount=0.99,
              tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2):
        for it in range(iterations):
            # Sample a random batch from the replay buffer
            batch_states, batch_next_states, batch_actions, batch_rewards, batch_dones = replay_buffer.sample(batch_size)
            state = torch.FloatTensor(batch_states)
            next_state = torch.FloatTensor(batch_next_states)
            action = torch.FloatTensor(batch_actions)
            reward = torch.FloatTensor(batch_rewards.reshape((batch_size, 1)))
            done = torch.FloatTensor(batch_dones.reshape((batch_size, 1)))

            # Compute the target Q value
            with torch.no_grad():
                noise = (torch.randn_like(action) * policy_noise).clamp(-noise_clip, noise_clip)
                next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)
                target_Q1 = self.critic1_target(next_state, next_action)
                target_Q2 = self.critic2_target(next_state, next_action)
                target_Q = torch.min(target_Q1, target_Q2)
                target_Q = reward + ((1 - done) * discount * target_Q)

            # Update critic 1
            current_Q1 = self.critic1(state, action)
            loss_Q1 = F.mse_loss(current_Q1, target_Q)
            self.critic1_optimizer.zero_grad()
            loss_Q1.backward()
            self.critic1_optimizer.step()

            # Update critic 2
            current_Q2 = self.critic2(state, action)
            loss_Q2 = F.mse_loss(current_Q2, target_Q)
            self.critic2_optimizer.zero_grad()
            loss_Q2.backward()
            self.critic2_optimizer.step()

            # Delayed update of the actor and the target networks
            if it % policy_freq == 0:
                # Update the actor
                actor_loss = -self.critic1(state, self.actor(state)).mean()
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                # Soft-update the target networks
                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
                for param, target_param in zip(self.critic1.parameters(), self.critic1_target.parameters()):
                    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
                for param, target_param in zip(self.critic2.parameters(), self.critic2_target.parameters()):
                    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def save(self, filename):
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.critic1.state_dict(), filename + "_critic1")
        torch.save(self.critic2.state_dict(), filename + "_critic2")

    def load(self, filename):
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_target.load_state_dict(torch.load(filename + "_actor"))
        self.critic1.load_state_dict(torch.load(filename + "_critic1"))
        self.critic1_target.load_state_dict(torch.load(filename + "_critic1"))
        self.critic2.load_state_dict(torch.load(filename + "_critic2"))
        self.critic2_target.load_state_dict(torch.load(filename + "_critic2"))

# Create the environment
env = gym.make('Pendulum-v0')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

# Create the TD3 agent
td3 = TD3(state_dim, action_dim, max_action)

# Set the replay buffer size and the number of training steps
# (ReplayBuffer is assumed to be defined elsewhere; see the sketch below)
replay_buffer = ReplayBuffer()
replay_buffer_size = 1000000
replay_buffer.init(replay_buffer_size, state_dim, action_dim)
iterations = 100000

# Train the TD3 agent
state, done = env.reset(), False
episode_reward = 0
episode_timesteps = 0
episode_num = 0
for t in range(iterations):
    episode_timesteps += 1

    # Select an action and step the environment
    action = td3.select_action(state)
    next_state, reward, done, _ = env.step(action)
    replay_buffer.add(state, next_state, action, reward, done)
    state = next_state
    episode_reward += reward

    # Start training once the buffer holds enough data
    if replay_buffer.size() > 1000:
        td3.train(replay_buffer, 100)

    # Print statistics at the end of each episode
    if done:
        print("Total Timesteps: {} Episode Num: {} Episode Timesteps: {} Reward: {}".format(
            t + 1, episode_num + 1, episode_timesteps, episode_reward))
        state, done = env.reset(), False
        episode_reward = 0
        episode_timesteps = 0
        episode_num += 1

# Save the model
td3.save("td3_pendulum")
```
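
The example above relies on a `ReplayBuffer` class that it never defines. As a rough illustration, a minimal buffer exposing the `init` / `add` / `size` / `sample` interface used above could look like the following sketch (this class is an assumption, not part of the original example):

```python
import numpy as np

class ReplayBuffer:
    """Minimal FIFO replay buffer matching the interface used in the example above."""

    def init(self, max_size, state_dim, action_dim):
        self.max_size = max_size
        self.ptr = 0      # next slot to write
        self.count = 0    # number of stored transitions
        self.state = np.zeros((max_size, state_dim), dtype=np.float32)
        self.next_state = np.zeros((max_size, state_dim), dtype=np.float32)
        self.action = np.zeros((max_size, action_dim), dtype=np.float32)
        self.reward = np.zeros(max_size, dtype=np.float32)
        self.done = np.zeros(max_size, dtype=np.float32)

    def add(self, state, next_state, action, reward, done):
        self.state[self.ptr] = state
        self.next_state[self.ptr] = next_state
        self.action[self.ptr] = action
        self.reward[self.ptr] = reward
        self.done[self.ptr] = float(done)
        self.ptr = (self.ptr + 1) % self.max_size  # overwrite the oldest entry when full
        self.count = min(self.count + 1, self.max_size)

    def size(self):
        return self.count

    def sample(self, batch_size):
        idx = np.random.randint(0, self.count, size=batch_size)
        return (self.state[idx], self.next_state[idx], self.action[idx],
                self.reward[idx], self.done[idx])
```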