tensorflow 读取cifar_Tensorflow官网教程:CIFAR-10分类代码阅读

1.题记

因为课程设计要用到TensorFlow,所以这几天在看TensorFlow官方给的几个示例代码,前面几个例子比较简单,看完之后试着重新写了一下,从CIFAR10开始我觉得需要做点笔记了。官网的例子是分成好几个模块实现。

图片来自TensorFlow中文社区

这样的好处是从程序设计的角度来说比较优美,但是从我们看代码的角度来说,不利于阅读逻辑的连贯性。我个人看代码的习惯是先把主线上的执行逻辑理顺了,之后再考虑其他的执行分支和一些细节,其他像属于优化或者扩展,最大可用性等等的代码放到后面看。

所以我的笔记和官网的代码与教程的不同之处在于,我把Demo的主要逻辑捋直了,然后加上了比较详细地注释,感兴趣的同学可以顺着看下来,不用再跳来跳去,应该可以稍微加快阅读这个Demo的速度。

2. cifar10介绍

CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题。任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵盖了10个类别:飞机, 汽车, 鸟, 猫, 鹿, 狗, 青蛙, 马, 船以及卡车。

谷歌这份Demo的目标是建立一个用于识别图像的相对较小的卷积神经网络。选择CIFAR-10是因为它的复杂程度足以用来检验TensorFlow中的大部分功能,并可将其扩展为更大的模型。与此同时由于模型较小所以训练速度很快,比较适合用来测试新的想法,检验新的技术。

3. 代码

1.导入库

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function

from datetime import datetime

import os.path

import time

import os

import re

import sys

import tarfile

import tensorflow.python.platform

from tensorflow.python.platform import gfile

import numpy as np

from six.moves import xrange # pylint: disable=redefined-builtin

import tensorflow as tf

2.部分全局变量

# 定义全局变量

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',

"""Directory where to write event logs """

"""and checkpoint.""")

tf.app.flags.DEFINE_integer('max_steps', 1000000,

"""Number of batches to run.""")

tf.app.flags.DEFINE_boolean('log_device_placement', False,

"""Whether to log device placement.""")

# 基本模型参数

tf.app.flags.DEFINE_integer('batch_size', 128,

"""Number of images to process in a batch.""")

tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',

"""Path to the CIFAR-10 data directory.""")

tf.app.flags.DEFINE_integer('log_frequency', 10,

"""How often to log results to the console.""")

DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'

3.下载数据集

国内网络环境的原因,这段可能执行不成功,大家可以去官网下载好数据集,然后放置在/tmp/cifar10_data路径下,解压后的效果如下:

数据集存放位置

# 检测本地是否有数据集

def maybe_download_and_extract():

"""Download and extract the tarball from Alex's website."""

dest_directory = FLAGS.data_dir # /tmp/cifar10_data

if not os.path.exists(dest_directory):

os.makedirs(dest_directory)

# 从URL中获得文件名

filename = DATA_URL.split('/')[-1]

# 合并文件路径

filepath = os.path.join(dest_directory, filename)

if not os.path.exists(filepath):

# 定义下载过程中打印日志的回调函数

def _progress(count, block_size, total_size):

sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,

float(count * block_size) / float(total_size) * 100.0))

sys.stdout.flush()

# 下载数据集

filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath,

reporthook=_progress)

print()

# 获得文件信息

statinfo = os.stat(filepath)

print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

# 解压缩

tarfile.open(filepath, 'r:gz').extractall(dest_directory)

maybe_download_and_extract()

4.

# 删除之前训练过程中产生的一些临时文件,并重新生成目录

if gfile.Exists(FLAGS.train_dir):

gfile.DeleteRecursively(FLAGS.train_dir)

gfile.MakeDirs(FLAGS.train_dir)

#定义记录训练步数的变量

global_step = tf.train.get_or_create_global_step()# tf.Variable(0, trainable=False)

5. 导入数据和标签

# 从 CIFAR-10 中导入数据和标签

IMAGE_SIZE = 24

NUM_CLASSES = 10

NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000

NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000

def distorted_inputs(batch_size):

"""

Returns:

i

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值