在指定环境下训练模型的过程中,当遇到输入或者从RL模型中返回的NaN或inf时,RL模型有完全崩溃的可能。
-
原因和方式
问题出现后,NaNs和infs不会崩溃,而是简单的通过训练传递,直到所有的浮点数收敛到NaN或inf。这符合IEEE浮点运算标准(IEEE754),标准指出:
可能出现的物种异常:
- 无效的操作符( − 1 \sqrt{-1} −1, inf$*$1, NaN mod 1, …)返回NaN
- 除以0:
- 如果运算对象非零(1/0, -2/0, …)返回 ± i n f \pm inf ±inf
- 如果运算对象是零(0/0)返回NaN
- 上溢(指数太高而无法表示)返回 ± i n f \pm inf ±inf
- 下溢(指数太低而无法表示)返回 0 0 0
- 不精确(以2为底时不能准确表示,例如1/5)返回四舍五入值(例如:
assert (1/5) * 3 == 0.6000000000000001
)
只有除以0会报错,其他方式只会静静传递。
在Python中,除以0会报如下错:
ZeroDivisionError: float division by zero
,其他会忽略。Numpy中默认警告:
RuntimeWarning: invalid value encountered
但不会停止代码。最差的情况,Tensorflow不会提示任何信息
import tensorflow as tf import numpy as np print("tensorflow test:") a = tf.constant(1.0) b = tf.constant(0.0) c = a / b sess = tf.Session() val = sess.run(c) # this will be quiet print(val) sess.close() print("\r\nnumpy test:") a = np.float64(1.0) b = np.float64(0.0) val = a / b # this will warn print(val) print("\r\npure python test:") a = 1.0 b = 0.0 val = a / b # this will raise an exception and halt. print(val)
不幸的是,大多数浮点运算都是用Tensorflow和Numpy处理的,这意味着当无效值出现时,你很可能得不到任何警告。
-
Numpy参数
Numpy有方便处理无效值的方法:
numpy.seterr
,它是为Python进程定义的,决定它如何处理浮点型错误。import numpy as np np.seterr(all='raise') # define before your code. print("numpy test:") a = np.float64(1.0) b = np.float64(0.0) val = a / b # this will now raise an exception instead of a warning. print(val)
不过这也会避免浮点数的溢出问题:
import numpy as np np.seterr(all='raise') # define before your code. print("numpy overflow test:") a = np.float64(10) b = np.float64(1000) val = a ** b # this will now raise an exception print(val)
不过无法避免传递问题:
import numpy as np np.seterr(all='raise') # define before your code. print("numpy propagation test:") a = np.float64('NaN') b = np.float64(1.0) val = a + b # this will neither warn nor raise anything print(val)
-
Tensorflow参数
Tensorflow会增加检查以侦测和处理无效值:
tf.add_check_numerics_ops
和tf.check_numerics
,然而,他们会增加Tensorflow图表处理,增加运算时间。import tensorflow as tf print("tensorflow test:") a = tf.constant(1.0) b = tf.constant(0.0) c = a / b check_nan = tf.add_check_numerics_ops() # add after your graph definition. sess = tf.Session() val, _ = sess.run([c, check_nan]) # this will now raise an exception print(val) sess.close()
这也会避免浮点数溢出问题:
import tensorflow as tf print("tensorflow overflow test:") check_nan = [] # the list of check_numerics operations a = tf.constant(10) b = tf.constant(1000) c = a ** b check_nan.append(tf.check_numerics(c, "")) # check the 'c' operations sess = tf.Session() val, _ = sess.run([c] + check_nan) # this will now raise an exception print(val) sess.close()
捕捉传播问题:
import tensorflow as tf print("tensorflow propagation test:") check_nan = [] # the list of check_numerics operations a = tf.constant('NaN') b = tf.constant(1.0) c = a + b check_nan.append(tf.check_numerics(c, "")) # check the 'c' operations sess = tf.Session() val, _ = sess.run([c] + check_nan) # this will now raise an exception print(val) sess.close()
-
VecChecNan包装器
为查明无效值源自何时何处,stable-baselines提出
VecChecknan
包装器。它会监控行动、观测、奖励,指明那种行动和观测导致了无效值以及从何处出现。
import gym from gym import spaces import numpy as np from stable_baselines import PPO2 from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan class NanAndInfEnv(gym.Env): """Custom Environment that raised NaNs and Infs""" metadata = {'render.modes': ['human']} def __init__(self): super(NanAndInfEnv, self).__init__() self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) def step(self, _action): randf = np.random.rand() if randf > 0.99: obs = float('NaN') elif randf > 0.98: obs = float('inf') else: obs = randf return [obs], 0.0, False, {} def reset(self): return [0.0] def render(self, mode='human', close=False): pass # Create environment env = DummyVecEnv([lambda: NanAndInfEnv()]) env = VecCheckNan(env, raise_exception=True) # Instantiate the agent model = PPO2('MlpPolicy', env) # Train the agent model.learn(total_timesteps=int(2e5)) # this will crash explaining that the invalid value originated from the environment.
-
RL模型超参数
依据你的超参数,NaN可能更经常出现。一个极好的例子:https://github.com/hill-a/stable-baselines/issues/340
要明白,虽然在大多数案例中默认超参数看起来可以跑通,不过在你的环境下很难最优。如果是这样,搞清楚每个超参数如何影响模型,以便你可以调参以得到稳定模型。或者,你可以尝试自动调参(参见RL Zoo)。
-
数据集中的缺失值
如果你的环境产生自外部数据集,确保数据集中不含NaNs。因为有时候数据集中会用NaNs代替缺失值。
这里有一些关于如何查找NaNs的阅读材料:点击链接。
以及用其他方式查找缺失值:点击链接。