Deep-Learing-Notes(3)

CIFAR-10 图像识别

CIFAR-10是一个日常生活物体的彩色图像数据集,说起来以前用caffe跑过这个入门,总之这次的对象就是它;

下载

首先用脚本下载;
cifar10_download.py

# coding:utf-8
# 引入当前目录中的已经编写好的cifar10模块
import cifar10
# 引入tensorflow
import tensorflow as tf

# tf.app.flags.FLAGS是TensorFlow内部的一个全局变量存储器,同时可以用于命令行参数的处理
FLAGS = tf.app.flags.FLAGS
# 在cifar10模块中预先定义了f.app.flags.FLAGS.data_dir为CIFAR-10的数据路径
# 我们把这个路径改为cifar10_data
FLAGS.data_dir = 'cifar10_data/'

# 如果不存在数据文件,就会执行下载
cifar10.maybe_download_and_extract()

关于FLAGS,这里写用一个脚本用来认识:

import tensorflow as tf

#三个参数分别是,变量名字,变量值,描述,当然不同类型的变量对应不同类型的函数
tf.app.flags.DEFINE_string('test','hello',"description")
tf.app.flags.DEFINE_integer('int_test',111,"description2")
tf.app.flags.DEFINE_boolean('BOOLtest',True,"description3")

FLAGS=tf.app.flags.FLAGS

def main(_):
     print(FLAGS.test)
     print(FLAGS.int_test)
     print(FLAGS.BOOLtest)
#这一段是用来运行主函数的
if __name__ == '__main__':
        tf.app.run()

cifar10官方代码

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import time

import tensorflow as tf

import cifar10

FLAGS = tf.app.flags.FLAGS
#上个例子已经较为清楚的说明了FLAGS的作用
tf.app.flags.DEFINE_string('train_dir', '/home/t64/Desktop/cifar10_data',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 10000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
tf.app.flags.DEFINE_integer('log_frequency', 10,
                            """How often to log results to the console.""")


def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()  

    #distorted用于数据增强,有缩放裁剪噪声等,但不会改变图像的标签
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    #接下来几排就是标准的套用模型,步骤
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)


def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
      tf.app.run()

这里写图片描述

可视化工具tensorboard

在进行训练的时候,指令

tensorboard --logdir cifar10_train/

会得到一个地址:http://ubuntu:6006 (举例)
在浏览器打开就能看到训练的详细信息:
这里写图片描述


这一篇是记的很水,六月过完一般才写了一篇,这个月都在忙着复习,说来惭愧,不是找理由,待修改。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值