1、定义hook钩子函数用于获取指定名称的中间数据
1、定义hook钩子类用于获取模型中指定名称的中间数据
class YourOwnHook(tf.train.SessionRunHook):
def __init__(self):
np.set_printoptions(suppress=True)
np.set_printoptions(linewidth=400)
def before_run(self, run_context):
"""返回SessionRunArgs和session run一起跑"""
v1 = tf.get_collection('logis')
prob = tf.get_collection('prob')
return tf.train.SessionRunArgs(fetches=[v1, prob])
def after_run(self, run_context, run_values):
v1, batch_labels = run_values.results
logger.info("logis value:{}".format(v1))
print("prob :",batch_labels)
2、标准的自定义的estimator以及设置钩子用于输出到tensorboard以及输出中间值
class MyEstimator(tf.estimator.Estimator):
def __init__(self,
model_dir,
hidden_units,