# -*- coding: utf-8 -*-
import tensorflow as tf
import os
#cifar_10=input_data
from six.moves import xrange
IMAGE_SISE=24 ## 原图像的尺度为32*32,但根据常识,信息部分通常位于图像的中央,这里定义了以中心裁剪后图像的尺寸
NUM_CLASSES=10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN=50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL=10000
#读取数据集,并根据数据集使用说明做数据预处理
def read_cifar(filename_queue):
class CIFAR10Record(object):
pass
return CIFAR10Record()
label_bytes=1 #2 for cifar_100
result.height=32 #结果中的行数
result.width=32 #结果中的列数
result.depth=3 #结果中的颜色通道数
image_bytes=result.height*result.width*result.depth
record_bytes=label_bytes+image_bytes
reader=tf.FixedLengthRecordReader(record_bytes)
result.key,value=reader.read(filename_queue) #读取一行记录,从filename_queue队列中获取文件名
record_bytes=tf.decode_raw(value,tf.uint8) #将长度为record_bytes的字符串转换为uint8的向量
result.label==tf.cast(tf.strided_slice(record_bytes,[0],[label_bytes],tf.int32)) # [0]和[label_bytes]分别表示待截取片段的起点和长度 ,转换int32
depth_major=tf.reshape(tf.strided_slice(record_bytes,[label_bytes],[label_bytes+image_bytes]),
[result.depth,result.height,result.width])
result.uint8image=tf.transpose(depth_major,[1,2,0]) # 将 [depth, height, width] 转换为[height, width, depth]
return result
#创建一个队列的批量图和标签
def _generate_image_and_label_batch(image,label,min_queue_examples,batch_size,shuffle):
# 创建一个混排样本的队列,然后从样本队列中读取 'batch_size'数量的 images + labels数据(每个样本都是由images + labels组成)
num_preprocess_threads=16 #预处理采用多线程
if shuffle:
images,label_batch=tf.train.shuffle_batch(
[image,label],
batch_size=batch_size,
num_threads=num_preprocess_threads,
capacity=min_queue_examples+3*batch_size
)
else:
images,label_batch,=tf.train.batch(
[image,label],
batch_size=batch_size,
num_threads=num_preprocess_threads,
capacity=min_queue_examples
)
tf.summary.image('images',images) #训练图像可视化
return images,tf.reshape(label_batch,[batch_size])
#使用Reader操作构建扭曲的输入(图像)用作CIFAR训练
def distorted_inputs(data_dir,batch_size):
filenames=[os.path.join(data_dir,'data_batch_%d.bin'%i)
for i in xrange(1,6)]
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Faile to fine file:'+f)
filename_queue=tf.train.string_input_producer(filenames) #创建文件名队列
read_input=read_cifar(filename_queue) #从文件名队列中读取样本
reshaped_image=tf.cast(read_input.uint8image,tf,float32)
height=IMAGE_SIZE #用于训练神经网络的图像处理,
width=IMAGE_SIZE #对图像进行了很多随机扭曲处理
distorted_image=tf.random_crop(reshaped_image,[height,width,3]) ##随机修建图像的某一块[height,width]区域
distorted_image=tf.image.random_flip_left_right(distorted_image) #随机水平翻转图像
distorted_image=tf.image.random_brightness(distorted_image,max_delta=63) #随机变换图像的亮度
distorted_image=tf.image.random_contrast(distorted_image,lower=0.2,upper=1.8) #随机变换图像的对比度
float_image=tf.image.per_image_standardization(distorted_image) #对图像进行标准化:减去均值并除以像素的方差
#设置张量的形状
float_image.set_shape([height,width,3])
read_input.label.set_shape([1])
#确保随机混排有很好的混合性
min_fraction_of_examples_queue=0.4
min_queue_examples=int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_queue)
print('Filling queue with %d CIFAR images before starting to train.'
'This will take a few minutes.'%min_queue_examples)
#通过构建一个样本队列来生成一批量的图像和标签
return _generate_image_and_label_batch(float_image,read_input.label,min_queue_examples,batch_size,shuffle=True)
#: 使用Reader ops操作构建CIFAR评估的输入
def inputs(eval_data,data_dir,batch_size):
if not eval_data:
filenames=[os.path.join(data_dir,'data_batch_%d.bin'% i)
for i in xrange(1,6)]
num_examples_per_epoch=NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
else:
filenames=[os.path.join(data_dir,'test_batch.bin')]
num_examples_per_epoch=NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file:'+f)
filename_queue=tf.train.string_input_producer(filenames)#创建一个文件名队列
read_input=read_cifar(filename_queue)
reshaped_image=tf.cast(read_input.uint8image,tf.float32)
height=IMAGE_SISE
width=IMAGE_SISE
resized_image=tf.image.resize_image_with_crop_or_pad(reshaped_image,
height,width)#裁剪图像的中心
float_image=tf.image.per_image_standardization(resized_image) #标准化:减去均值并除以像素的方差
float_image.set_shape([height,width,3]) #设置张量的形状
read_input.label.set_shape([1])
min_fraction_of_example_in_queue=0.4
min_queue_examples=int(num_examples_per_epoch * min_fraction_of_example_in_queue)
return _generate_image_and_label_batch(float_image,read_input.label,min_queue_examples,batch_size,shuffle=False)
cifar10_input 图像处理