tf.estimator是tensorflow的高阶api,使用下面代码可以实现限制显存,0.8代表使用80%的显存。
session_config = tf.ConfigProto(log_device_placement=True,allow_soft_placement=True)
session_config.gpu_options.per_process_gpu_memory_fraction = 0.8
run_config = tf.estimator.RunConfig(
session_config=session_config,
model_dir=FLAGS.output_dir,
save_checkpoints_steps=FLAGS.save_checkpoints_steps)
之前搜到另一种写法,亲测这种写法在replace之后,前面设置的model_ dir和save_checkpoints_steps均失效了,很尴尬~
run_config = tf.estimator.RunConfig(
model_dir=FLAGS.output_dir,
save_checkpoints_steps=FLAGS.save_checkpoints_steps)
session_config = tf.ConfigProto(log_device_placement=True,allow_soft_placement=True)
session_config.gpu_options.per_process_gpu_memory_fraction = 0.8
run_config = tf.estimator.RunConfig().replace(session_config=session_config)
所以还是使用第一种写法吧。