stable_baselines3 Learning Notes

An earlier post of mine is 强化学习自动泊车的问题_highwayenv车一直转圈圈-CSDN博客 (a CSDN post on reinforcement-learning automatic parking, where the highway-env car kept driving in circles).

I had planned to write up a summary once training was going well, but I must admit the results have never lived up to expectations, so I am writing the summary first anyway and leaving the training itself for later.

In highway_env/envs/common/observation.py of the project, I changed

    obs = OrderedDict(
            [
                ("observation", obs / self.scales),
                ("achieved_goal", obs / self.scales),
                ("desired_goal", goal / self.scales),
            ]
        )

to

    obs = OrderedDict(
            [
                # ("observation", obs / self.scales),
                ("achieved_goal", obs / self.scales),
                ("desired_goal", goal / self.scales),
                ("render_img", self.env.render()),
            ]
        )

This gives me an extra render_img entry with shape (180, 360, 3) (I also modified the render size). The model's input now contains both vectors (achieved_goal and desired_goal) and an image (render_img), so the network combines fully connected layers with a CNN.
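To confirm the new observation layout, a quick check like this should do (a minimal sketch; the exact shapes assume my patched observation.py and modified screen size):

    import gymnasium as gym
    import highway_env

    env = gym.make('parking-v0', render_mode='rgb_array')
    obs, info = env.reset()
    for key, value in obs.items():
        print(key, value.shape)
    # Expected with the patch above:
    # achieved_goal (6,)
    # desired_goal (6,)
    # render_img (180, 360, 3)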

Printing model.policy gives the following:

MultiInputPolicy(
  (actor): Actor(
    (features_extractor): CombinedExtractor(
      (extractors): ModuleDict(
        (achieved_goal): Flatten(start_dim=1, end_dim=-1)
        (desired_goal): Flatten(start_dim=1, end_dim=-1)
        (render_img): NatureCNN(
          (cnn): Sequential(
            (0): Conv2d(3, 32, kernel_size=(8, 8), stride=(4, 4))
            (1): ReLU()
            (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
            (3): ReLU()
            (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
            (5): ReLU()
            (6): Flatten(start_dim=1, end_dim=-1)
          )
          (linear): Sequential(
            (0): Linear(in_features=49856, out_features=256, bias=True)
            (1): ReLU()
          )
        )
      )
    )
    (latent_pi): Sequential(
      (0): Linear(in_features=268, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=512, bias=True)
      (5): ReLU()
    )
    (mu): Linear(in_features=512, out_features=2, bias=True)
    (log_std): Linear(in_features=512, out_features=2, bias=True)
  )
  (critic): ContinuousCritic(
    (features_extractor): CombinedExtractor(
      (extractors): ModuleDict(
        (achieved_goal): Flatten(start_dim=1, end_dim=-1)
        (desired_goal): Flatten(start_dim=1, end_dim=-1)
        (render_img): NatureCNN(
          (cnn): Sequential(
            (0): Conv2d(3, 32, kernel_size=(8, 8), stride=(4, 4))
            (1): ReLU()
            (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
            (3): ReLU()
            (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
            (5): ReLU()
            (6): Flatten(start_dim=1, end_dim=-1)
          )
          (linear): Sequential(
            (0): Linear(in_features=49856, out_features=256, bias=True)
            (1): ReLU()
          )
        )
      )
    )
    (qf0): Sequential(
      (0): Linear(in_features=270, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=512, bias=True)
      (5): ReLU()
      (6): Linear(in_features=512, out_features=1, bias=True)
    )
    (qf1): Sequential(
      (0): Linear(in_features=270, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=512, bias=True)
      (5): ReLU()
      (6): Linear(in_features=512, out_features=1, bias=True)
    )
  )
  (critic_target): ContinuousCritic(
    (features_extractor): CombinedExtractor(
      (extractors): ModuleDict(
        (achieved_goal): Flatten(start_dim=1, end_dim=-1)
        (desired_goal): Flatten(start_dim=1, end_dim=-1)
        (render_img): NatureCNN(
          (cnn): Sequential(
            (0): Conv2d(3, 32, kernel_size=(8, 8), stride=(4, 4))
            (1): ReLU()
            (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
            (3): ReLU()
            (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
            (5): ReLU()
            (6): Flatten(start_dim=1, end_dim=-1)
          )
          (linear): Sequential(
            (0): Linear(in_features=49856, out_features=256, bias=True)
            (1): ReLU()
          )
        )
      )
    )
    (qf0): Sequential(
      (0): Linear(in_features=270, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=512, bias=True)
      (5): ReLU()
      (6): Linear(in_features=512, out_features=1, bias=True)
    )
    (qf1): Sequential(
      (0): Linear(in_features=270, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=512, bias=True)
      (5): ReLU()
      (6): Linear(in_features=512, out_features=1, bias=True)
    )
  )
)

This is the basic network structure. Note how the dimensions add up: latent_pi's in_features of 268 is 256 (CNN features) + 6 + 6 (the two goal vectors), and the critic's 270 further appends the 2-dimensional action. Once you are familiar with the structure, you can customize it.

One of the layers above is

(0): Linear(in_features=49856, out_features=256, bias=True)

A fully connected layer that large seemed wasteful to me, so I wanted to customize the CNN.
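As a sanity check, that 49856 falls straight out of NatureCNN's convolution arithmetic on a (3, 180, 360) input (a minimal sketch):

    # Reproduce in_features=49856 for NatureCNN on a (3, 180, 360) image.
    def conv_out(size, kernel, stride):
        # Output size of a valid (no-padding) convolution.
        return (size - kernel) // stride + 1

    h, w = 180, 360
    for kernel, stride in [(8, 4), (4, 2), (3, 1)]:  # NatureCNN's three conv layers
        h, w = conv_out(h, kernel, stride), conv_out(w, kernel, stride)

    print(h, w, 64 * h * w)  # 19 41 49856

That single 256-unit linear layer therefore holds 49856 * 256 ≈ 12.8M weights.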

The modified code is as follows:

import gymnasium as gym
import highway_env

# Agent
from stable_baselines3 import HerReplayBuffer, SAC

import sys
from tqdm.notebook import trange
from utils import record_videos
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import torch as th
from torch import nn
# from stable_baselines3.common.torch_layers import NatureCNN
from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space

LEARNING_STEPS = 2e4 # @param {type: "number"}


class ExponentialCNN(BaseFeaturesExtractor):
    """
    CNN from DQN Nature paper:
        Mnih, Volodymyr, et al.
        "Human-level control through deep reinforcement learning."
        Nature 518.7540 (2015): 529-533.

    :param observation_space:
    :param features_dim: Number of features extracted.
        This corresponds to the number of unit for the last layer.
    :param normalized_image: Whether to assume that the image is already normalized
        or not (this disables dtype and bounds checks): when True, it only checks that
        the space is a Box and has 3 dimensions.
        Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
    """

    def __init__(
        self,
        observation_space: gym.Space,
        features_dim: int = 512,
        normalized_image: bool = False,
    ) -> None:
        assert isinstance(observation_space, spaces.Box), (
            "ExponentialCNN must be used with a gym.spaces.Box "
            f"observation space, not {observation_space}"
        )
        super().__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        assert is_image_space(observation_space, check_channels=False, normalized_image=normalized_image), (
            "You should use ExponentialCNN "
            f"only with images not with {observation_space}\n"
            "(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n"
            "If you are using a custom environment,\n"
            "please check it using our env checker:\n"
            "https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html.\n"
            "If you are using `VecNormalize` or already normalized channel-first images "
            "you should pass `normalize_images=False`: \n"
            "https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html"
        )
        n_input_channels = observation_space.shape[0]
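        # Six stride-2 convolutions roughly halve the spatial size each time;
        # for a (3, 180, 360) channels-first input the feature maps go
        # 180x360 -> 89x179 -> 43x88 -> 20x43 -> 9x20 -> 3x9 -> 1x4,
        # so the flattened output is 512 * 1 * 4 = 2048.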
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 16, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=0),
            # nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=0),
            # nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=0),
            nn.Flatten(),
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.cnn(observations)

class CustomCombinedExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Dict):
        # We do not know features-dim here before going over all the items,
        # so put something dummy for now. PyTorch requires calling
        # nn.Module.__init__ before adding modules
        super().__init__(observation_space, features_dim=1)

        extractors = {}

        total_concat_size = 0
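        # Note: this single MLP instance is registered under both the
        # "achieved_goal" and "desired_goal" keys below, so the two goal
        # vectors share weights; instantiate one per key if separate
        # weights are wanted. in_features=6 matches parking-v0's
        # 6-dimensional goal vectors.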
        full_connect_extractor = nn.Sequential(
            nn.Linear(in_features=6, out_features=512, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=512, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=512, bias=True),
            nn.ReLU()
        )
        # We need to know size of the output of this extractor,
        # so go over all the spaces and compute output feature sizes
        for key, subspace in observation_space.spaces.items():
            if key == "render_img":
                # We will just downsample one channel of the image by 4x4 and flatten.
                # Assume the image is single-channel (subspace.shape[0] == 0)
                extractors[key] = ExponentialCNN(observation_space=observation_space["render_img"])
                total_concat_size += 2048
            else:
                # Run through a simple MLP
                extractors[key] = full_connect_extractor
                total_concat_size += 512

        self.extractors = nn.ModuleDict(extractors)

        # Update the features dim manually
        self._features_dim = total_concat_size

    def forward(self, observations) -> th.Tensor:
        encoded_tensor_list = []

        # self.extractors contain nn.Modules that do all the processing.
        for key, extractor in self.extractors.items():
            encoded_tensor_list.append(extractor(observations[key]))
        # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension.
        return th.cat(encoded_tensor_list, dim=1)


env = gym.make('parking-v0', render_mode='rgb_array')
her_kwargs = dict(n_sampled_goal=4, goal_selection_strategy='future')
model = SAC('MultiInputPolicy', env, replay_buffer_class=HerReplayBuffer,
            replay_buffer_kwargs=her_kwargs, verbose=1,
            tensorboard_log="logs",
            buffer_size=int(4e5),
            learning_rate=1e-3,
            learning_starts=int(1e3),
            gamma=0.95, batch_size=1024, tau=0.05,
            policy_kwargs=dict(features_extractor_class=CustomCombinedExtractor, net_arch=[512, 512]))

# model = SAC.load("/home/zhuqinghua/workspace/HighwayEnv/scripts/parking_her/model1", env = env)
# model.learn(int(LEARNING_STEPS))
# model.save("parking_her/model")

N_EPISODES = 10  # @param {type: "integer"}

env = gym.make('parking-v0', render_mode='rgb_array')
env = record_videos(env)
for episode in trange(N_EPISODES, desc="Test episodes"):
    obs, info = env.reset()
    done = truncated = False
    while not (done or truncated):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
env.close()
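One caveat in CustomCombinedExtractor above: total_concat_size += 2048 is hard-coded for the (3, 180, 360) render size. A more robust variant of the render_img branch (a sketch following the same dummy-forward trick NatureCNN itself uses; not what my script currently does) would infer the size at construction time:

cnn = ExponentialCNN(observation_space=observation_space["render_img"])
with th.no_grad():
    # Run one sample through the CNN to measure its flattened output size.
    sample = th.as_tensor(observation_space["render_img"].sample()[None]).float()
    n_flatten = cnn(sample).shape[1]  # 512 * 1 * 4 = 2048 for (3, 180, 360)
extractors[key] = cnn
total_concat_size += n_flatten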

Looking at model.policy again now, it becomes the following. The dimensions again add up: latent_pi's in_features of 3072 is 2048 (render_img CNN) + 512 + 512 (the two goal MLPs), and the critic's 3074 further appends the 2-dimensional action:

MultiInputPolicy(
  (actor): Actor(
    (features_extractor): CustomCombinedExtractor(
      (extractors): ModuleDict(
        (achieved_goal): Sequential(
          (0): Linear(in_features=6, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=512, bias=True)
          (3): ReLU()
          (4): Linear(in_features=512, out_features=512, bias=True)
          (5): ReLU()
        )
        (desired_goal): Sequential(
          (0): Linear(in_features=6, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=512, bias=True)
          (3): ReLU()
          (4): Linear(in_features=512, out_features=512, bias=True)
          (5): ReLU()
        )
        (render_img): ExponentialCNN(
          (cnn): Sequential(
            (0): Conv2d(3, 16, kernel_size=(4, 4), stride=(2, 2))
            (1): ReLU()
            (2): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2))
            (3): ReLU()
            (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
            (5): ReLU()
            (6): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
            (7): ReLU()
            (8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2))
            (9): ReLU()
            (10): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2))
            (11): Flatten(start_dim=1, end_dim=-1)
          )
        )
      )
    )
    (latent_pi): Sequential(
      (0): Linear(in_features=3072, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): ReLU()
    )
    (mu): Linear(in_features=512, out_features=2, bias=True)
    (log_std): Linear(in_features=512, out_features=2, bias=True)
  )
  (critic): ContinuousCritic(
    (features_extractor): CustomCombinedExtractor(
      (extractors): ModuleDict(
        (achieved_goal): Sequential(
          (0): Linear(in_features=6, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=512, bias=True)
          (3): ReLU()
          (4): Linear(in_features=512, out_features=512, bias=True)
          (5): ReLU()
        )
        (desired_goal): Sequential(
          (0): Linear(in_features=6, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=512, bias=True)
          (3): ReLU()
          (4): Linear(in_features=512, out_features=512, bias=True)
          (5): ReLU()
        )
        (render_img): ExponentialCNN(
          (cnn): Sequential(
            (0): Conv2d(3, 16, kernel_size=(4, 4), stride=(2, 2))
            (1): ReLU()
            (2): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2))
            (3): ReLU()
            (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
            (5): ReLU()
            (6): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
            (7): ReLU()
            (8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2))
            (9): ReLU()
            (10): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2))
            (11): Flatten(start_dim=1, end_dim=-1)
          )
        )
      )
    )
    (qf0): Sequential(
      (0): Linear(in_features=3074, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=1, bias=True)
    )
    (qf1): Sequential(
      (0): Linear(in_features=3074, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=1, bias=True)
    )
  )
  (critic_target): ContinuousCritic(
    (features_extractor): CustomCombinedExtractor(
      (extractors): ModuleDict(
        (achieved_goal): Sequential(
          (0): Linear(in_features=6, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=512, bias=True)
          (3): ReLU()
          (4): Linear(in_features=512, out_features=512, bias=True)
          (5): ReLU()
        )
        (desired_goal): Sequential(
          (0): Linear(in_features=6, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=512, bias=True)
          (3): ReLU()
          (4): Linear(in_features=512, out_features=512, bias=True)
          (5): ReLU()
        )
        (render_img): ExponentialCNN(
          (cnn): Sequential(
            (0): Conv2d(3, 16, kernel_size=(4, 4), stride=(2, 2))
            (1): ReLU()
            (2): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2))
            (3): ReLU()
            (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
            (5): ReLU()
            (6): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
            (7): ReLU()
            (8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2))
            (9): ReLU()
            (10): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2))
            (11): Flatten(start_dim=1, end_dim=-1)
          )
        )
      )
    )
    (qf0): Sequential(
      (0): Linear(in_features=3074, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=1, bias=True)
    )
    (qf1): Sequential(
      (0): Linear(in_features=3074, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=1, bias=True)
    )
  )
)
