TensorFlow's secret weapon: tensorflow.contrib.learn

Experiment Environment

Python 3.5. It is recommended to install the TensorFlow environment through virtualenv, which builds good habits and makes projects easier to deploy and migrate. Docker also works, but virtualenv is more lightweight.

Code

The code below is taken from the official TensorFlow GitHub repository.

#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Model training for Iris data set using Validation Monitor."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec

tf.logging.set_verbosity(tf.logging.INFO)

# Data sets
IRIS_TRAINING = os.path.join(os.path.dirname(__file__), "iris_training.csv")
IRIS_TEST = os.path.join(os.path.dirname(__file__), "iris_test.csv")


def main(unused_argv):
  # Load datasets.
  training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
      filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float)
  test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
      filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float)

  # Specify that all features have real-value data
  feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

  validation_metrics = {
      "accuracy": MetricSpec(
                          metric_fn=tf.contrib.metrics.streaming_accuracy,
                          prediction_key="classes"),
      "recall": MetricSpec(
                          metric_fn=tf.contrib.metrics.streaming_recall,
                          prediction_key="classes"),
      "precision": MetricSpec(
                          metric_fn=tf.contrib.metrics.streaming_precision,
                          prediction_key="classes")
                        }
  validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
      test_set.data,
      test_set.target,
      every_n_steps=50,
      metrics=validation_metrics,
      early_stopping_metric="loss",
      early_stopping_metric_minimize=True,
      early_stopping_rounds=200)

  # Build 3 layer DNN with 10, 20, 10 units respectively.
  classifier = tf.contrib.learn.DNNClassifier(
      feature_columns=feature_columns,
      hidden_units=[10, 20, 10],
      n_classes=3,
      model_dir="/tmp/iris_model",
      config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1))

  # Fit model.
  classifier.fit(x=training_set.data,
                 y=training_set.target,
                 steps=2000,
                 monitors=[validation_monitor])

  # Evaluate accuracy.
  accuracy_score = classifier.evaluate(
      x=test_set.data, y=test_set.target)["accuracy"]
  print("Accuracy: {0:f}".format(accuracy_score))

  # Classify two new flower samples.
  new_samples = np.array(
      [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
  y = list(classifier.predict(new_samples))
  print("Predictions: {}".format(str(y)))


if __name__ == "__main__":
  tf.app.run()

Code Walkthrough

  • Turn on training log output
tf.logging.set_verbosity(tf.logging.INFO)

TensorFlow ships with five logging levels: DEBUG, INFO, WARNING, ERROR, and FATAL, the usual set in Google tooling. The verbosity setting specifies the minimum level that will be output.
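For reference, a minimal, self-contained sketch of how the verbosity threshold filters messages (TF 1.x contrib-era API; the message strings are invented for illustration):

import tensorflow as tf

tf.logging.set_verbosity(tf.logging.INFO)

tf.logging.debug("suppressed: DEBUG is below the INFO threshold")
tf.logging.info("printed: INFO meets the threshold")
tf.logging.warning("printed: higher levels always pass the filter")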

  • Key metrics to record. Each entry maps a name to a MetricSpec; a sketch of adding a custom metric follows the snippet below.
  validation_metrics = {
      "accuracy": MetricSpec(
                          metric_fn=tf.contrib.metrics.streaming_accuracy,
                          prediction_key="classes"),
      "recall": MetricSpec(
                          metric_fn=tf.contrib.metrics.streaming_recall,
                          prediction_key="classes"),
      "precision": MetricSpec(
                          metric_fn=tf.contrib.metrics.streaming_precision,
                          prediction_key="classes")
                        }
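The dictionary keys become the names under which the values are reported, and metric_fn can be any function that accepts predictions and labels and returns a streaming metric. As a hedged sketch, the entry below (the "mean_class" key and the mean_predicted_class helper are invented for illustration) could be appended to the same dict:
  def mean_predicted_class(predictions, labels):
    # Toy streaming metric: the average of the predicted class ids.
    del labels  # not needed for this illustrative metric
    return tf.contrib.metrics.streaming_mean(tf.to_float(predictions))

  validation_metrics["mean_class"] = MetricSpec(
      metric_fn=mean_predicted_class,
      prediction_key="classes")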
  • Build the validation monitor. During training, every fixed number of steps (the every_n_steps parameter) the model is evaluated against the test_set data, recording the validation_metrics built above. With the settings in this code, training stops early if the loss has not improved within 200 steps (early_stopping_metric_minimize=True means the metric is expected to decrease). An alternative configuration is sketched after the snippet below.
  validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
      test_set.data,
      test_set.target,
      every_n_steps=50,
      metrics=validation_metrics,
      early_stopping_metric="loss",
      early_stopping_metric_minimize=True,
      early_stopping_rounds=200)
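early_stopping_metric is not restricted to the loss; any key reported during validation can drive early stopping, with early_stopping_metric_minimize flipped to match. A hedged sketch of stopping when validation accuracy stops improving, using the same constructor:
  accuracy_monitor = tf.contrib.learn.monitors.ValidationMonitor(
      test_set.data,
      test_set.target,
      every_n_steps=50,
      metrics=validation_metrics,
      early_stopping_metric="accuracy",
      early_stopping_metric_minimize=False,  # accuracy should rise, not fall
      early_stopping_rounds=200)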
  • Build the classifier. This constructs a DNN classifier whose inputs are specified by feature_columns, with three hidden layers of 10, 20, and 10 units respectively. Because the validation monitor relies on saved checkpoints, the tf.contrib.learn.DNNClassifier constructor must set config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1) so that a checkpoint is saved every second into the directory given by model_dir. Note that if the program is run several times, training inspects that directory and continues from the most recent checkpoint, which is handy for quickly validating an early-stage model; see the restore sketch after the snippet below.
  classifier = tf.contrib.learn.DNNClassifier(
      feature_columns=feature_columns,
      hidden_units=[10, 20, 10],
      n_classes=3,
      model_dir="/tmp/iris_model",
      config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1))
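Because everything the estimator needs is written to model_dir, a later run can point a fresh DNNClassifier at the same directory and it restores the latest checkpoint instead of starting from scratch. A minimal sketch, assuming the previous run has already saved checkpoints to /tmp/iris_model:
  restored_classifier = tf.contrib.learn.DNNClassifier(
      feature_columns=feature_columns,
      hidden_units=[10, 20, 10],
      n_classes=3,
      model_dir="/tmp/iris_model")
  # Evaluates with the weights saved by the previous run; no new fit() needed.
  print(restored_classifier.evaluate(x=test_set.data, y=test_set.target))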
  • Train the model. Call the classifier's fit function with the training data, and do not forget to pass in the monitor configured earlier. The monitors parameter accepts a list of monitors, so several can be attached at once. A sketch of inspecting the monitor after training follows the snippet below.
  classifier.fit(x=training_set.data,
                 y=training_set.target,
                 steps=2000,
                 monitors=[validation_monitor])
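After fit returns, the monitor records whether it triggered the stop and which step scored best on the early-stopping metric. A hedged sketch of reading those ValidationMonitor properties:
  if validation_monitor.early_stopped:
    print("Stopped early at step {}, best loss = {}".format(
        validation_monitor.best_step, validation_monitor.best_value))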
  • Evaluate the model's accuracy. Call the classifier's evaluate function with the test set and read off the metric of interest, usually accuracy. A sketch that also pulls out additional metrics follows the snippet below.
  accuracy_score = classifier.evaluate(
      x=test_set.data, y=test_set.target)["accuracy"]
  print("Accuracy: {0:f}".format(accuracy_score))

Reposted from: https://my.oschina.net/u/2276931/blog/829529
