DRL实战:SAC | Gym环境中经典控制问题Pendulum-v0及调参记录

本文介绍了一种强化学习算法——软 Actor-Critic (SAC) 的实现过程,特别针对连续动作空间的问题,如倒立摆控制任务。文中详细展示了 SAC 算法的原理及其在 Pendulum-v0 环境中的应用。
摘要由CSDN通过智能技术生成

1 Pendulum-v0介绍

(原文此处为一张插图,图片未能保留)

倒立摆问题是控制领域文献中的经典问题。在这个版本中,摆杆从随机位置开始,目标是把它向上摆起并保持竖直。
动作空间和状态空间都是连续的:状态维度 env.observation_space.shape[0] = 3,动作维度 env.action_space.shape[0] = 1,动作取值上限 env.action_space.high[0] = 2。
详情可参考链接:https://www.jianshu.com/p/af3a7853268f

2 代码

该代码参考伯禹学习平台对SAC算法原理的讲解与SAC实战的示例:https://hrl.boyuai.com/chapter/2/sac%E7%AE%97%E6%B3%95

import random
import gym
import numpy as np

import torch
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt


class ReplayBuffer:
    """Fixed-capacity ring buffer of (state, action, reward, done, next_state) transitions.

    Once `capacity` transitions are stored, each new push overwrites the oldest
    slot. `num_transition` counts every push ever made and keeps growing past
    `capacity`.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []         # stored transitions, at most `capacity` entries
        self.num_transition = 0  # total number of pushes so far

    def push(self, state, action, reward, done, next_state):
        """Store one transition, overwriting the oldest one once the buffer is full."""
        record = (state, action, reward, done, next_state)
        if self.num_transition >= self.capacity:
            self.buffer[self.num_transition % self.capacity] = record
        else:
            self.buffer.append(record)
        self.num_transition += 1

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions uniformly at random.

        Returns a 5-tuple of numpy arrays (states, actions, rewards, dones,
        next_states), each stacked along a new leading batch axis.
        Raises ValueError (from random.sample) if fewer than `batch_size`
        transitions are stored.
        """
        picked = random.sample(self.buffer, batch_size)
        return tuple(np.stack(field) for field in zip(*picked))

    def __len__(self):
        """Number of transitions currently held (never more than `capacity`)."""
        return len(self.buffer)


class PolicyNetContinuous(torch.nn.Module):
    """Tanh-squashed Gaussian policy for bounded continuous actions.

    A hidden ReLU layer feeds two heads: the mean `mu` and a softplus-positive
    standard deviation `std`. forward() draws a reparameterised sample
    u ~ N(mu, std), maps it to action = tanh(u) * action_bound, and returns the
    log-density of tanh(u), including the tanh change-of-variables correction.
    """

    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super(PolicyNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_std = torch.nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, x):
        """Sample an action for state(s) `x`.

        Returns:
            action: tanh(u) * action_bound, same shape as the mean head output
                (last dim = action_dim).
            log_prob: log-density of tanh(u), summed over the action dimension
                with the last dim kept, i.e. shape (..., 1).
        """
        x = F.relu(self.fc1(x))
        mu = self.fc_mu(x)
        std = F.softplus(self.fc_std(x))  # softplus is a smooth ReLU; keeps std positive
        dist = Normal(mu, std)
        # rsample() uses the reparameterisation trick, so gradients flow through the sample.
        normal_sample = dist.rsample()
        squashed = torch.tanh(normal_sample)  # in (-1, 1)
        # Change of variables for a = tanh(u): log p(a) = log p(u) - log(1 - tanh(u)^2).
        # `squashed` already equals tanh(u); applying tanh to it again would give
        # the wrong Jacobian term. 1e-7 guards against log(0) when |tanh(u)| -> 1.
        log_prob = dist.log_prob(normal_sample) - torch.log(1 - squashed.pow(2) + 1e-7)
        # Action dimensions are independent, so the joint log-density is their sum.
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        action = squashed * self.action_bound  # rescale to [-action_bound, action_bound]
        return action, log_prob


class QValueNetContinuous(torch.nn.Module):
    """Q(s, a) critic: an MLP over the concatenated state and action.

    Two ReLU hidden layers of width `hidden_dim`, followed by a linear head
    that produces one scalar per (state, action) pair.
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        # These attribute names become the state_dict keys that
        # load_state_dict() relies on, so they stay fc1 / fc2 / fc_out.
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc_out = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        """Return Q-values of shape (batch, 1) for batched states `x` and actions `a`."""
        h = torch.cat((x, a), 1)
        for hidden in (self.fc1, self.fc2):
            h = F.relu(hidden(h))
        return self.fc_out(h)


class SACContinuous:
    """SAC (Soft Actor-Critic) agent for continuous action spaces.

    Holds a tanh-squashed Gaussian actor, two Q critics with soft-updated
    target copies, a learnable entropy temperature (optimised as log_alpha),
    and its own replay buffer.

    Args:
        state_dim, hidden_dim, action_dim: network sizes.
        action_bound: scale applied to the tanh-squashed action.
        actor_lr, critic_lr, alpha_lr: Adam learning rates.
        target_entropy: entropy target used to tune alpha.
        tau: soft-update coefficient for the target critics.
        gamma: discount factor.
        device: torch device the networks are placed on.
        replay_buffer_size: replay buffer capacity (default 2000).
        batch_size: mini-batch size drawn by update() (default 16).
    """

    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,
                 actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,
                 device, replay_buffer_size=2000, batch_size=16):
        self.replay_buffer = ReplayBuffer(replay_buffer_size)
        self.batch_size = batch_size
        self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim,
                                         action_bound).to(device)  # policy network
        self.critic_1 = QValueNetContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # first Q network
        self.critic_2 = QValueNetContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # second Q network
        self.target_critic_1 = QValueNetContinuous(state_dim, hidden_dim,
                                                   action_dim).to(device)  # first target Q
        self.target_critic_2 = QValueNetContinuous(state_dim, hidden_dim,
                                                   action_dim).to(device)  # second target Q
        # Target networks start as exact copies of their online critics.
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),
                                                   lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),
                                                   lr=critic_lr)
        # Optimise log(alpha) instead of alpha: keeps alpha positive and
        # makes training more stable. Initial alpha = 0.01.
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr)
        self.target_entropy = target_entropy
        self.gamma = gamma
        self.tau = tau
        self.device = device

    def take_action(self, state):
        """Sample a (stochastic) action for a single, unbatched observation.

        Returns the action as a list of Python floats, one per action
        dimension (a one-element list when action_dim == 1).
        """
        state = torch.tensor(np.array(state), dtype=torch.float).to(self.device)
        with torch.no_grad():  # acting only; no graph needed
            action = self.actor(state)[0]
        return action.cpu().numpy().tolist()

    def calc_target(self, rewards, next_states, dones):
        """Soft TD target.

        Computes r + gamma * (1 - done) * (min(Q1', Q2')(s', a') + alpha * H),
        where a' ~ pi(.|s') and H = -log pi(a'|s'). The result carries no
        gradient.
        """
        with torch.no_grad():
            next_actions, log_prob = self.actor(next_states)
            entropy = -log_prob
            q1_value = self.target_critic_1(next_states, next_actions)
            q2_value = self.target_critic_2(next_states, next_actions)
            next_value = torch.min(q1_value, q2_value) + self.log_alpha.exp() * entropy
            return rewards + self.gamma * next_value * (1 - dones)

    def soft_update(self, net, target_net):
        """Polyak-average `net` into `target_net`: target <- (1 - tau) * target + tau * net."""
        for param_target, param in zip(target_net.parameters(),
                                       net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) +
                                    param.data * self.tau)

    def update(self):
        """Run one SAC step on a mini-batch of `self.batch_size` transitions.

        Updates both critics, the actor, alpha, and finally soft-updates the
        target critics. Raises ValueError (from the buffer's random.sample)
        if fewer than `batch_size` transitions are stored.
        """
        states, actions, rewards, dones, next_states = self.replay_buffer.sample(self.batch_size)

        states = torch.tensor(states, dtype=torch.float).to(self.device)
        actions = torch.tensor(actions, dtype=torch.float).to(self.device)
        # (B,) -> (B, 1) so rewards/dones broadcast against the (B, 1) Q-values.
        rewards = torch.tensor(rewards, dtype=torch.float).unsqueeze(1).to(self.device)
        next_states = torch.tensor(next_states, dtype=torch.float).to(self.device)
        dones = torch.tensor(dones, dtype=torch.float).unsqueeze(1).to(self.device)

        # Optional Pendulum reward reshaping (disabled):
        # rewards = (rewards + 8.0) / 8.0

        # --- Critic update (both Q networks regress onto the same soft target) ---
        td_target = self.calc_target(rewards, next_states, dones)
        critic_1_loss = F.mse_loss(self.critic_1(states, actions), td_target.detach())
        critic_2_loss = F.mse_loss(self.critic_2(states, actions), td_target.detach())
        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # --- Actor update: maximise min Q + alpha * entropy ---
        new_actions, log_prob = self.actor(states)
        entropy = -log_prob
        q1_value = self.critic_1(states, new_actions)
        q2_value = self.critic_2(states, new_actions)
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy -
                                torch.min(q1_value, q2_value))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # --- Temperature update: push the policy entropy towards target_entropy ---
        alpha_loss = torch.mean(
            (entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)


if __name__ == '__main__':
    env_name = 'Pendulum-v0'
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    action_bound = env.action_space.high[0]  # maximum action magnitude (2 for Pendulum)

    # Fix all random seeds for reproducibility.
    random.seed(0)
    np.random.seed(0)
    env.seed(0)
    torch.manual_seed(0)

    actor_lr = 3e-4
    critic_lr = 3e-3
    alpha_lr = 3e-4
    epoch = 100
    hidden_dim = 128
    gamma = 0.99
    tau = 0.005  # soft-update coefficient
    replay_buffer_size = 2000
    batch_size = 16
    target_entropy = -env.action_space.shape[0]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    agent = SACContinuous(state_dim, hidden_dim, action_dim, action_bound,
                          actor_lr, critic_lr, alpha_lr, target_entropy, tau,
                          gamma, device)

    reward_his = []
    for e in range(epoch):
        reward_epoch = 0
        state = env.reset()
        done = False
        step = 0
        while not done:
            env.render()
            action = agent.take_action(state)
            next_state, reward, done, info = env.step(action)
            reward_epoch += reward
            # Pendulum has no terminal state: its episodes end only when the
            # TimeLimit wrapper truncates them. A truncated step must not be
            # stored as terminal, or the critic stops bootstrapping there.
            terminal = done and not info.get('TimeLimit.truncated', False)
            agent.replay_buffer.push(state, action, reward, terminal, next_state)
            # Start learning only after the replay buffer has filled up once.
            if agent.replay_buffer.num_transition > agent.replay_buffer.capacity:
                agent.update()
            state = next_state
            step += 1
        reward_his.append(reward_epoch)
        if e % 10 == 0:
            print('epoch ', e, 'reward_epoch ', reward_epoch, 'step ', step)

    plt.plot(reward_his, "r-", label="reward_his")
    plt.legend()
    plt.show()

3 调参

3.1 初步训练

(原文此处为一张插图,图片未能保留)

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值