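TensorFlow's slim module wraps a lot of functionality and is simple and convenient to use, but it offers no obvious API for controlling what gets logged during training. When you train a model with slim, print statements usually do not behave as expected, because the training loop runs inside slim and you never touch the session directly. This post shows how to use train_step to control the log output.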
1. slim.train
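slim.train() (in full, slim.learning.train()) is defined as follows. After the model and its parameters have been set up, this is the function you call to train the model: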
def train(train_op,
          logdir,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          session_wrapper=None,
          trace_every_n_steps=None,
          ignore_live_threads=False)
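Next we need access to the session, so that we can log values through it. To get it, define your own function and pass it to slim.train() as train_step_fn. Inside that function, call the default train_step function first, then add whatever logic you need.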
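To make the calling convention concrete, here is a TensorFlow-free sketch of the idea. It assumes, based on the learning.py source linked at the end of this post, that the training loop calls train_step_fn(sess, train_op, global_step, train_step_kwargs) once per iteration and expects (total_loss, should_stop) back. FakeSession, default_step, logging_step, run_loop, and the "max_steps" key are all hypothetical stand-ins invented for this illustration; they are not TF-Slim APIs.

```python
import logging

logging.basicConfig(level=logging.INFO)


class FakeSession:
    """Hypothetical stand-in for tf.Session: run() just calls a zero-argument callable."""

    def run(self, fetch):
        return fetch()


def default_step(session, train_op, global_step, train_step_kwargs):
    """Stand-in for slim's train_step: run one update and report whether to stop."""
    total_loss = session.run(train_op)
    # "max_steps" is a made-up key for this sketch, not a real slim kwarg.
    should_stop = session.run(global_step) >= train_step_kwargs["max_steps"]
    return total_loss, should_stop


def logging_step(session, train_op, global_step, train_step_kwargs):
    """Custom step function: delegate to the default step, then log every 4th step."""
    total_loss, should_stop = default_step(
        session, train_op, global_step, train_step_kwargs)
    step = session.run(global_step)
    if step % 4 == 0:
        logging.info("step %d: loss = %.3f", step, total_loss)
        logging_step.logged_steps.append(step)
    # The loop unpacks this pair, so the custom function must return it.
    return total_loss, should_stop


# Function attribute used to remember which steps were logged.
logging_step.logged_steps = []


def run_loop(train_step_fn, max_steps):
    """Stand-in for the training loop: call train_step_fn until it asks to stop."""
    state = {"step": 0}

    def train_op():
        # Pretend one optimizer update happened and return a shrinking loss.
        state["step"] += 1
        return 1.0 / state["step"]

    def global_step():
        return state["step"]

    session = FakeSession()
    kwargs = {"max_steps": max_steps}
    while True:
        total_loss, should_stop = train_step_fn(
            session, train_op, global_step, kwargs)
        if should_stop:
            return total_loss
```

Calling run_loop(logging_step, max_steps=10) runs ten iterations and logs at steps 4 and 8. The point of the sketch is that the custom function receives the session on every iteration, which is what makes per-step logging possible.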
2. train_step
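import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.learning import train_step

def train_step_fn(session, train_op, global_step, train_step_kwargs):
    # Run the default training step first, then add the custom logging.
    total_loss, should_stop = train_step(session, train_op, global_step, train_step_kwargs)
    # Read the current step from the session so the check changes on every iteration.
    current_step = session.run(global_step)
    if current_step % 4 == 0:
        tf.logging.info('step %d: learning rate = %s',
                        current_step, session.run(train_step_fn.lr))
    # slim.learning.train unpacks this return value, so it must be passed back.
    return total_loss, should_stop

# Attach the learning-rate tensor as a function attribute so train_step_fn can reach it.
train_step_fn.lr = learning_rate

# train_op, TRAIN_LOG, init_fn, global_step, steps and sess_config are defined elsewhere.
final_loss = slim.learning.train(train_op, TRAIN_LOG,
                                 train_step_fn=train_step_fn,
                                 init_fn=init_fn,
                                 global_step=global_step,
                                 number_of_steps=steps,
                                 save_summaries_secs=60,
                                 save_interval_secs=600,
                                 session_config=sess_config,
                                 )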
This gives you access to the session inside the training loop, so you can print whatever custom log output you need.
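Function attributes work, but a closure is another way to hand extra tensors to the step function. The following is only a sketch of that alternative: make_train_step_fn, StubSession, and stub_base_step are hypothetical names made up for this example, and the stub session looks values up in a dict so the code runs without TensorFlow. It assumes the same four-argument calling convention as above.

```python
def make_train_step_fn(base_step, extra_fetches, every_n, log_fn=print):
    """Build a step function that logs `extra_fetches` every `every_n` steps."""

    def train_step_fn(session, train_op, global_step, train_step_kwargs):
        # Delegate the actual training step to the wrapped function.
        total_loss, should_stop = base_step(
            session, train_op, global_step, train_step_kwargs)
        step = session.run(global_step)
        if step % every_n == 0:
            # Fetch the extra values through the same session and log them.
            values = session.run(extra_fetches)
            log_fn("step %d: %s" % (step, values))
        return total_loss, should_stop

    return train_step_fn


class StubSession:
    """Hypothetical stand-in for tf.Session; resolves fetches from a dict."""

    def __init__(self, values):
        self.values = values

    def run(self, fetches):
        # Support a dict of fetches, mirroring how a real session accepts one.
        if isinstance(fetches, dict):
            return {name: self.values[key] for name, key in fetches.items()}
        return self.values[fetches]


def stub_base_step(session, train_op, global_step, train_step_kwargs):
    """Stand-in for the default step: bump the step counter and never stop."""
    session.values[global_step] += 1
    return session.run(train_op), False
```

With TF-Slim, the same factory would presumably be used as train_step_fn=make_train_step_fn(train_step, {'learning_rate': learning_rate}, every_n=4, log_fn=tf.logging.info), which keeps the logging interval and the tensors to fetch in one place.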
Reference: https://github.com/google-research/tf-slim/blob/master/tf_slim/learning.py