TensorFlow系列——在自定义的标准estimator中使用tensorboard及打印中间数据

1、定义hook钩子函数用于获取指定名称的中间数据

1、定义hook钩子类用于获取模型中指定名称的中间数据

class YourOwnHook(tf.train.SessionRunHook):
    def __init__(self):
        np.set_printoptions(suppress=True)
        np.set_printoptions(linewidth=400)

    def before_run(self, run_context):
        """返回SessionRunArgs和session run一起跑"""
        v1 = tf.get_collection('logis')
        prob = tf.get_collection('prob')
        return tf.train.SessionRunArgs(fetches=[v1, prob])
    def after_run(self, run_context, run_values):
        v1, batch_labels = run_values.results
        logger.info("logis value:{}".format(v1))
        print("prob :",batch_labels)

2、标准的自定义的estimator以及设置钩子用于输出到tensorboard以及输出中间值

class MyEstimator(tf.estimator.Estimator):
    def __init__(self,
                          model_dir,
                          hidden_units,
     
  • 5
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值