While running a trained model on gym's CartPole-v1, I wrote the following code:
obs = env.reset()
action,_ = model.predict(obs)
It threw a huge error, a really long one:
ValueError Traceback (most recent call last)
Cell In[43], line 1
----> 1 action,_ = model.predict(obs)
File c:\users\dell\appdata\local\programs\python\python39\lib\site-packages\stable_baselines3\common\base_class.py:556, in BaseAlgorithm.predict(self, observation, state, episode_start, deterministic)
536 def predict(
537 self,
538 observation: Union[np.ndarray, Dict[str, np.ndarray]],
(...)
541 deterministic: bool = False,
542 ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
543 """
544 Get the policy action from an observation (and optional hidden state).
545 Includes sugar-coating to handle different observations (e.g. normalizing images).
(...)
554 (used in recurrent policies)
555 """
--> 556 return self.policy.predict(observation, state, episode_start, deterministic)
File c:\users\dell\appdata\local\programs\python\python39\lib\site-packages\stable_baselines3\common\policies.py:357, in BasePolicy.predict(self, observation, state, episode_start, deterministic)
354 # Check for common mistake that the user does not mix Gym/VecEnv API
355 # Tuple obs are not supported by SB3, so we can safely do that check
356 if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict):
--> 357 raise ValueError(
358 "You have passed a tuple to the predict() function instead of a Numpy array or a Dict. "
359 "You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) "
360 "vs `obs = vec_env.reset()` (SB3 VecEnv). "
361 "See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 "
362 "and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
363 )
365 obs_tensor, vectorized_env = self.obs_to_tensor(observation)
367 with th.no_grad():
ValueError: You have passed a tuple to the predict() function instead of a Numpy array or a Dict. You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) vs `obs = vec_env.reset()` (SB3 VecEnv). See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api
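In short, the Gym API and the SB3 VecEnv API are being mixed: predict() expects a NumPy array (or a dict), but it was given a tuple, so it raised an error.
I searched the whole internet and couldn't find how to fix it, and asking GPT didn't help either.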
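To see exactly what trips the error, here is a minimal sketch that copies the guard condition from `BasePolicy.predict` shown in the traceback (line 356) into a standalone function. The function name `looks_like_gym_reset_output` is hypothetical, and the CartPole-like observation is just a made-up 4-element NumPy array.

```python
import numpy as np


def looks_like_gym_reset_output(observation) -> bool:
    """Same test as the guard in SB3's BasePolicy.predict (traceback line 356):
    a 2-tuple whose second element is a dict is taken to be Gym's (obs, info)."""
    return (
        isinstance(observation, tuple)
        and len(observation) == 2
        and isinstance(observation[1], dict)
    )


# Made-up CartPole-style observation: cart position, cart velocity,
# pole angle, pole angular velocity.
obs = np.array([0.01, -0.02, 0.03, 0.04], dtype=np.float32)
info = {}

gym_style = (obs, info)  # what env.reset() returns under the Gym API
print(looks_like_gym_reset_output(gym_style))  # True  -> predict() would raise ValueError
print(looks_like_gym_reset_output(obs))        # False -> a bare array is accepted
```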
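After fiddling with it for a long time, I suddenly noticed the `obs, info = env.reset()` in the error message, so I went and read how reset() is defined in the Gym docs (https://www.gymlibrary.dev/api/core/#gym.Env.reset):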
observation (object)
– Observation of the initial state. This will be an element of observation_space (typically a numpy array) and is analogous to the observation returned by step().
info (dictionary)
– This dictionary contains auxiliary information complementing observation. It should be analogous to the info returned by step().
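So reset() really doesn't return just the array: it returns a tuple made of a NumPy array (obs) and a dictionary (info).
That makes it easy: receive the two values in separate variables. The fix:
obs, info = env.reset()
action, _ = model.predict(obs)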
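Checking with type() confirms it: obs is a numpy.ndarray and info is a dict.
Yep, that's it. Pass obs into predict() and it runs.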
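Putting the fix and the type() check together, here is a minimal sketch of a full episode under the Gym API. Gym and SB3 are replaced by two hypothetical stubs (`FakeCartPole`, `FakeModel`) so it runs without either library; the return shapes follow Gym >= 0.26: `reset()` gives `(obs, info)` and `step()` gives `(obs, reward, terminated, truncated, info)`.

```python
import numpy as np


class FakeCartPole:
    """Hypothetical stand-in for gym.make("CartPole-v1") using the Gym >= 0.26 API."""

    def __init__(self, max_steps: int = 5):
        self.max_steps = max_steps
        self.t = 0

    def reset(self):
        self.t = 0
        return np.zeros(4, dtype=np.float32), {}  # (obs, info)

    def step(self, action):
        self.t += 1
        obs = np.full(4, self.t, dtype=np.float32)
        terminated = False
        truncated = self.t >= self.max_steps  # end the episode after max_steps
        return obs, 1.0, terminated, truncated, {}


class FakeModel:
    """Hypothetical stand-in for a trained model: predict() returns (action, state)."""

    def predict(self, observation):
        assert not isinstance(observation, tuple), "pass obs, not (obs, info)"
        return 0, None


env = FakeCartPole()
model = FakeModel()

result = env.reset()
print(type(result))            # <class 'tuple'>  (this is what caused the error)
obs, info = result             # unpack into two variables
print(type(obs), type(info))   # <class 'numpy.ndarray'> <class 'dict'>

total_reward = 0.0
terminated = truncated = False
while not (terminated or truncated):
    action, _ = model.predict(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
print(total_reward)            # 5.0
```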
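Update 2024.5.10--------------------------------------------------------------------------------------
It turns out that my setup code was originally meant to be:
import gym  # with SB3 >= 2.0, `import gymnasium as gym` is the recommended import
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

env = gym.make(environment_name, render_mode="human")  # environment_name = "CartPole-v1"
env = DummyVecEnv([lambda: env])  # wrap in SB3's VecEnv so reset() returns obs only
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)  # log_path: TensorBoard log dir
I had left out the second line, so `env` was still a plain Gym env whose reset() returns (obs, info), and `obs = env.reset()` handed that tuple to predict(), which is exactly the API mix-up the error describes. With the second line added, env.reset() follows the SB3 VecEnv API and returns only the observation, so the workaround above isn't needed.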
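Why does the DummyVecEnv line make the difference? Here is a minimal sketch, assuming a much-simplified imitation of what a VecEnv does to return values; `TinyVecEnv` and `GymStyleEnv` are hypothetical classes, not SB3 code. The wrapper takes a list of env factories (like the `[lambda: env]` above), stacks observations into a batch, and keeps info dicts out of reset()'s return value, which is why `obs = env.reset()` works after wrapping.

```python
import numpy as np


class GymStyleEnv:
    """Hypothetical env with the Gym >= 0.26 API (reset -> (obs, info))."""

    def reset(self):
        return np.zeros(4, dtype=np.float32), {"note": "reset"}

    def step(self, action):
        return np.ones(4, dtype=np.float32), 1.0, False, False, {}


class TinyVecEnv:
    """Hypothetical, simplified imitation of a VecEnv: reset() returns only a
    batch of observations, step() returns (obs, rewards, dones, infos)."""

    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]  # call each zero-argument factory

    def reset(self):
        obs = [env.reset()[0] for env in self.envs]  # drop each info dict
        return np.stack(obs)                          # shape (n_envs, obs_dim)

    def step(self, actions):
        results = [env.step(a) for env, a in zip(self.envs, actions)]
        obs = np.stack([r[0] for r in results])
        rewards = np.array([r[1] for r in results], dtype=np.float32)
        dones = np.array([r[2] or r[3] for r in results])  # terminated or truncated
        infos = [r[4] for r in results]
        return obs, rewards, dones, infos


vec_env = TinyVecEnv([lambda: GymStyleEnv()])
obs = vec_env.reset()          # a plain array, so predict() accepts it directly
print(type(obs), obs.shape)    # <class 'numpy.ndarray'> (1, 4)
obs, rewards, dones, infos = vec_env.step([0])
print(rewards, dones)          # [1.] [False]
```

With the real DummyVecEnv the loop has the same shape (`obs = env.reset()`, then `obs, rewards, dones, infos = env.step(action)`); the real class also resets finished sub-environments automatically, which this sketch leaves out.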