tf.estimator API技术手册(3)——BaselineClassifier(基线分类器)

tf.estimator API技术手册(3)——BaselineClassifier(基线分类器)

(一)简 介
(二)初始化
(三)属 性(Properties)
(四)主要方法(Methods)
    (1)evaluate(评估)
    (2)predict(预测)
    (3)train(训练)

(一)简 介

该类继承自Estimator,定义在tensorflow/python/estimator/canned/baseline.py中,可以用来构建一个简单的基线,该分类器会忽视特征的值并将去学习预测每个标签的平均值。对于单标签问题,它将预测标签中呈现的类的概率分布,多于多标签问题,它将预测每个类的正例的分数。
示例如下:
在这里插入图片描述

(二)初始化

初始化一个基线分类器实例:
在这里插入图片描述参数如下:

  • model_dir:

    保存模型参数、计算图等的目录,也可以从这个目录中加载checkpoints文件以继续训练当前已保存的模型。

  • n_classes:

    标签的种类,默认为两类,必须比1大。注意:类标签是表示类索引的整数,对于任意标签值,首先转换为类索引。

  • weight_column:

    由 tf.feature_column.numeric_column创建的一个字符串或者数字列用来呈现特征列。他将会被乘以example的训练损失。label_vocabulary:可选字符串列表,具有定义标签词汇的大小[n_classes],只支持n_classes大于2.

  • optimizer:

    字符串,TensorFlow优化器对象, 或者可以创建一个用于训练的优化器.。如果没有自行制定,将会使用FtrlOptimizer,并将设定学习率为0.3。

  • config:

    一个运行配置对象,用来配置运行时间。

  • loss_reduction:

    tf.losses包含方法中的一个,用来描述如何减少训练损失,默认使用SUM方法。

(三)属 性(Properties)

  • config
  • model_dir
  • model_fn
    Returns the model_fn which is bound to self.params.
    返回:
    model_fn 附有以下标记: def model_fn(features, labels, mode, config)

(四)主要方法(Methods)

(1)evaluate(评估)

在这里插入图片描述评估函数,使用input_fn给出的评估数据评估训练好的模型,参数列表如下:

  • input_fn:
    一个用来构造用于评估的数据的函数,这个函数应该构造和返回如下的值:一个tf.data.Dataset对象或者一个包含 (features, labels)的元组,它们应当满足model_fn函数对输入数据的要求,在后面的实例中我们会详细介绍。
  • checkpoint_path:
    用来保存训练好的模型
  • name:
    如果用户需要在不同的数据集上运行多个评价,如训练集和测试集,则为要进行评估的名称,不同的评估度量被保存在单独的文件夹中,并分别出现在tensorboard中。

(2)predict(预测)

在这里插入图片描述使用训练好的模型对新实例进行预测,以下为参数列表:

  • input_fn:

    一个用来构造用于评估的数据的函数,这个函数应该构造和返回如下的值:一个tf.data.Dataset对象或者一个包含 (features, labels)的元组,它们应当满足model_fn函数对输入数据的要求,在后面的实例中我们会详细介绍。

  • predict_keys:

    预测函数最终会返回一系列的结果,但我们可以有选择地让其输出,可供选择的keys列表为[‘logits’, ‘logistic’, ‘probabilities’, ‘class_ids’, ‘classes’],如果不指定的话,默认返回所有值。

  • hooks:

    tf.train.SessionRunHook的子类实例列表,在预测调用中用于传回。

  • checkpoint_path:

    训练好的模型的目录

  • yield_single_examples:

    可以选择False或是True,如果选择False,由model_fn返回整个批次,而不是将批次分解为单个元素。当model_fn返回的一些的张量的第一维度和批处理数量不相等时,这个功能是很用的。

(3)train(训练)

在这里插入图片描述
用于训练模型,以下为参数列表:

  • input_fn:

    一个用来构造用于评估的数据的函数,这个函数应该构造和返回如下的值:一个tf.data.Dataset对象或者一个包含 (features, labels)的元组,它们应当满足model_fn函数对输入数据的要求,在后面的实例中我们会详细介绍。

  • hooks:

    tf.train.SessionRunHook的子类实例列表,在预测调用中用于传回。

  • steps:

    模型训练的次数,如果不指定,则会一直训练知道input_fn传回的数据消耗完为止。如果你不想要增量表现,就设置max_steps来替代,注意设置了steps,max_steps必须为None,设置了max_steps,steps必须为None。

  • max_steps:

    模型训练的总次数,注意设置了steps,max_steps必须为None,设置了max_steps,steps必须为None。

  • saving_listeners:

    CheckpointSaverListener对象的列表,用于在检查点保存之前或之后立即运行的回调。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值