在使用混合输入(图像 + 向量)时,发现图像部分似乎一直没有经过卷积网络训练。这部分的核心是 CombinedExtractor 混合特征提取器,源码如下:
class CombinedExtractor(BaseFeaturesExtractor):
    """
    Features extractor for Dict observation spaces.

    For every key in the Dict space a dedicated submodule is built:
    image-like sub-spaces are encoded with a NatureCNN, everything else is
    simply flattened. The per-key feature vectors are then concatenated
    along dim=1 to form the combined feature vector.

    :param observation_space: the Dict observation space to extract features from
    :param cnn_output_dim: number of features each CNN submodule outputs
        (defaults to 256 to avoid exploding network sizes)
    :param normalized_image: whether to assume the image is already normalized
        (disables the dtype and bounds checks): when True, a sub-space only
        needs to be a 3-dimensional Box to count as an image; otherwise it
        must have dtype uint8 and bounds [0, 255].
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        cnn_output_dim: int = 256,
        normalized_image: bool = False,
    ) -> None:
        # The true feature dim is only known after walking every sub-space,
        # so register a placeholder first and patch it below.
        super().__init__(observation_space, features_dim=1)

        submodules: Dict[str, nn.Module] = {}
        concat_size = 0
        for key, subspace in observation_space.spaces.items():
            if is_image_space(subspace, normalized_image=normalized_image):
                # Image-like input: dedicated CNN encoder.
                submodules[key] = NatureCNN(
                    subspace,
                    features_dim=cnn_output_dim,
                    normalized_image=normalized_image,
                )
                concat_size += cnn_output_dim
            else:
                # Vector-like input: flatten only, no learned encoder here.
                submodules[key] = nn.Flatten()
                concat_size += get_flattened_obs_dim(subspace)

        self.extractors = nn.ModuleDict(submodules)
        # Replace the placeholder with the real concatenated size.
        self._features_dim = concat_size

    def forward(self, observations: TensorDict) -> th.Tensor:
        encoded = [extractor(observations[key]) for key, extractor in self.extractors.items()]
        return th.cat(encoded, dim=1)
可以看到,图像要走卷积分支,必须先通过 is_image_space 检查。再看 is_image_space 的源码,发现检测条件比较严格:例如数据类型必须是 np.uint8、取值范围必须是 [0, 255],否则就需要显式设置 normalized_image。使用时需要多加注意,否则图像会被当作普通向量直接 Flatten,训练半天模型也不对。
def is_image_space(
    observation_space: spaces.Space,
    check_channels: bool = False,
    normalized_image: bool = False,
) -> bool:
    """
    Check whether an observation space has the shape, limits and dtype
    of a valid image.

    The check is conservative: when in doubt it returns False.
    Valid images are GrayScale, RGB or RGBD with values in [0, 255].

    :param observation_space: the space to inspect
    :param check_channels: whether to also check the number of channels
        (e.g. with frame-stacking the space may have more channels than expected)
    :param normalized_image: whether to assume the image is already normalized
        (disables the dtype and bounds checks): when True, only requires a
        3-dimensional Box; otherwise dtype must be uint8 and bounds [0, 255].
    :return: True if the space looks like a valid image space
    """
    # Anything that is not a 3-dimensional Box cannot be an image.
    if not isinstance(observation_space, spaces.Box) or len(observation_space.shape) != 3:
        return False

    if not normalized_image:
        # dtype must be uint8 ...
        if observation_space.dtype != np.uint8:
            return False
        # ... and the value range must be exactly [0, 255].
        if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
            return False

    if not check_channels:
        # Channel count check explicitly skipped.
        return True

    # Channel axis depends on the layout (channels-first vs channels-last).
    if is_image_space_channels_first(observation_space):
        n_channels = observation_space.shape[0]
    else:
        n_channels = observation_space.shape[-1]
    # GrayScale, RGB, RGBD
    return n_channels in (1, 3, 4)
再看一下自己状态空间的定义:
# set up spaces
self.single_observation_space = gym.spaces.Dict()
self.single_observation_space["policy"] = gym.spaces.Dict()
# Declare the image sub-space as uint8 with bounds [0, 255]: without an
# explicit dtype the Box defaults to float32, which fails is_image_space's
# dtype check, so the "img" key would silently go through nn.Flatten
# instead of the NatureCNN branch of CombinedExtractor.
# NOTE(review): the observations the env actually returns must then also be
# uint8 (or alternatively keep float32 and pass normalized_image=True via
# policy_kwargs) — confirm against the camera output.
self.single_observation_space["policy"]["img"] = gym.spaces.Box(
    low=0,
    high=255,
    shape=(self.cfg.tiled_camera.height, self.cfg.tiled_camera.width, self.cfg.num_channels),
    dtype=np.uint8,
)
# Unbounded vector observations: handled by the Flatten branch, as intended.
self.single_observation_space["policy"]["vec"] = gym.spaces.Box(
    low=-np.inf, high=np.inf, shape=(self.num_observations_vec,)
)
打印出来是这样的
Dict('img': Box(0.0, 255.0, (48, 64, 1), float32), 'vec': Box(-inf, inf, (13,), float32))
可以看到 img 的 dtype 是 float32 而不是 np.uint8,又没有设置 normalized_image,所以 is_image_space 返回 False,图像被当作普通向量 Flatten 处理,卷积网络根本没有参与训练。另外对于 GrayScale、RGB、RGBD,如果开启 check_channels,还要注意通道数必须是 1、3 或 4。