在之前的博客中已经对CIFAR-10做了整体的解析。但是,目前从 tensorflow/models/tree/master/tutorials/image/cifar10 下载的代码,在运行 cifar10_train.py 时训练所用的是 binary 版(适用于C语言程序)的数据集。
那么想训练CIFAR-10 python version数据集该怎么修改代码呢?
其实主要需要修改的是 cifar10_input.py 文件。Python 版数据集的存储格式与 binary 版不同(具体格式请参阅 Alex Krizhevsky 个人主页上的 The CIFAR-10 dataset 页面),因此导入数据时解析数据集的代码也就不同。Python 版的读取方式如下:
def unpickle(file):
    """Load a pickled batch file and return the object stored in it.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object. The file is loaded with ``encoding='bytes'``,
        so 8-bit strings written by Python 2 come back as ``bytes`` objects.
    """
    import pickle

    # NOTE(review): unpickling can execute arbitrary code; only use this on
    # dataset files obtained from a trusted source.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch
这里就不啰嗦啦,直接向大家奉上完整代码:
from __future__ import print_function
import os
import tensorflow as tf
import pickle as pickle
import numpy as np
from PIL import Image
#encoding:utf-8
from scipy import ndimage
# Global constants describing the CIFAR-10 data set.
# CIFAR-10 images are 32x32; IMAGE_SIZE is the 24x24 size they will be
# distorted to.
IMAGE_SIZE = 24
NUM_CLASSES = 10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
def read_cifar10_python_pickles(filenames):
    """Read several pickled CIFAR-10 batch files and concatenate their contents.

    Each pickle is expected to be a dict providing the keys 'data', 'labels'
    and 'filenames' (it may also contain others, e.g. 'batch_label').

    Args:
        filenames: Iterable of paths to the pickled batch files.

    Returns:
        A tuple ``(data, labels, f_names)`` of numpy arrays, each built by
        concatenating the per-file values along axis 0 in the order given.

    Raises:
        ValueError: If ``filenames`` is empty or a listed file does not exist.
    """
    data_parts = []
    label_parts = []
    name_parts = []

    for pickle_file in filenames:
        if not tf.gfile.Exists(pickle_file):
            raise ValueError('Failed to find file: ' + pickle_file)
        # pickle.load requires a file opened in binary mode.
        # NOTE(review): unpickling can execute arbitrary code; only load
        # dataset files obtained from a trusted source.
        with open(pickle_file, 'rb') as p:
            save = pickle.load(p, encoding='iso-8859-1')
        s_data = save['data']
        s_labels = np.array(save['labels'])
        s_f_names = np.array(save['filenames'])
        # Drop the reference to the loaded dict before reading the next file.
        del save
        print('data set', s_data.shape, s_labels.shape)
        data_parts.append(s_data)
        label_parts.append(s_labels)
        name_parts.append(s_f_names)

    if not data_parts:
        raise ValueError('No CIFAR-10 batch files were given')

    # Concatenate once at the end instead of growing the arrays with
    # np.append inside the loop, which re-copies everything on each file.
    data = np.concatenate(data_parts, axis=0)
    labels = np.concatenate(label_parts, axis=0)
    f_names = np.concatenate(name_parts, axis=0)
    print('Data set: ', data.shape, len(labels))
    return data, labels, f_names
def read_cifar10_python_pickle(filename):
    """Read a single pickled CIFAR-10 batch file.

    The pickle is expected to be a dict providing the keys 'data', 'labels'
    and 'filenames'.

    Args:
        filename: Path of the pickled batch file.

    Returns:
        A tuple ``(data, labels, f_names)``: the stored 'data' value as-is,
        and the 'labels' and 'filenames' values converted to numpy arrays.

    Raises:
        ValueError: If ``filename`` does not exist.
    """
    if not tf.gfile.Exists(filename):
        raise ValueError('Failed to find file: ' + filename)
    # NOTE(review): unpickling can execute arbitrary code; only load dataset
    # files obtained from a trusted source.
    with open(filename, 'rb') as p:
        save = pickle.load(p, encoding='iso-8859-1')
    data = save['data']
    labels = np.array(save['labels'])
    f_names = np.array(save['filenames'])
    # Drop the reference to the loaded dict; only the three values are kept.
    del save
    print('data set', data.shape, labels.shape)
    return data, labels, f_names
def read_cifar10_to_queue(filenames):
    """Load the batch files and feed data, labels and file names into queues.

    Args:
        filenames: Paths of the pickled batch files, forwarded to
            read_cifar10_python_pickles.

    Returns:
        A tuple ``(data_queue, labels_queue, f_names_queue)`` holding the
        queues returned by ``tf.train.input_producer`` for the three arrays.
    """
    data, labels, f_names = read_cifar10_python_pickles(filenames)
    # One queue per array. NOTE(review): shuffle=False is presumably required
    # so that the three queues emit their elements in the same order and an
    # image stays paired with its label - confirm before enabling shuffling.
    data_queue = tf.train.input_producer(data, shuffle=False)
    labels_queue = tf.train.input_producer(labels, shuffle=False)
    f_names_queue = tf.train.input_producer(f_names, shuffle=False)
    return data_queue, labels_queue, f_names_queue
def read_cifar10_reader(data_q, labels_q):
    """Dequeue one element from each queue and return them as a pair.

    Args:
        data_q: Queue object exposing ``dequeue()`` that yields image data.
        labels_q: Queue object exposing ``dequeue()`` that yields labels.

    Returns:
        A tuple ``(data, label)`` with the results of ``data_q.dequeue()``
        and ``labels_q.dequeue()``, called in that order.
    """
    return data_q.dequeue(), labels_q.dequeue()
def read_cifar10(filenames):
    """Build the tensors for one CIFAR-10 example read from pickled batches.

    Args:
        filenames: Paths of the pickled batch files, forwarded to
            read_cifar10_to_queue.

    Returns:
        An object with the following attributes:
            height: 32.
            width: 32.
            depth: 3.
            label: The dequeued label cast to tf.int32.
            uint8image: The dequeued data reshaped to [depth, height, width]
                and then transposed to [height, width, depth].
    """
    class CIFAR10Record(object):
        pass

    result = CIFAR10Record()
    result.height = 32
    result.width = 32
    result.depth = 3

    data_q, label_q, _ = read_cifar10_to_queue(filenames)
    data, label = read_cifar10_reader(data_q, label_q)
    print(data.get_shape(), data.dtype)
    print(label.get_shape(), label.dtype)

    result.label = tf.cast(label, tf.int32)
    # NOTE(review): the dtype of `data` is whatever the pickled 'data' array
    # holds - presumably uint8, as the attribute name suggests; confirm.
    depth_major = tf.reshape(data, [result.depth, result.height, result.width])
    # Convert from [depth, height, width] to [height, width, depth].
    result.uint8image = tf.transpose(depth_major, [1, 2, 0])
    print(depth_major.get_shape(), depth_major.dtype)
    print(result.label.get_shape(), result.label.dtype)
    return result
# 构建一个排列后的一组图片和分类
def _generate_image_and_label_batch(image, label, min_queue_examples,
batch_size, shuffle):
"""Construct a queued batch of images and labels.
Args:
image: 3-D Tensor of [height, width, 3] of type.float32.
label: 1-D Tensor of type.int32
min_queue_examples: int32, minimum number of samples to retain
in the queue that provides of batches of examples.
batch_size: Number of images per batch.
shuffle: boolean indicating whether to use a shuffling queue.
Returns:
images: Images. 4D tensor of [batch_size, height, width, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
"""
# Create a queue that shuffles the examples, and then
# read 'batch_size' images + labels from the example queue.
num_preprocess_threads = 16
if shuffle:
images, label_batch = tf.train.shuffle_batch(
[image, label],
batch_size=batch_size,
num_threads=num_preprocess_threads,
capacity=min_queue_examples + 3 * batch_size,
min_after_dequeue=min_queue_examples)
else:
# tf.train.batch(tensors, batch_size, num_threads=1, capacity=32,
# enqueue_many=False, shapes=None, dynamic_pad=False,
# allow_smaller_final_batch=False, shared_name=None, name=None)
# 这里是用队列实现,已经默认使用enqueue_runner将enqueue_runner加入到Graph'senqueue_runner集合中
# 其默认enqueue_many=False时,输入的tensor为一个样本【x,y,z】,输出为Tensor的一批样本
# capacity:队列中允许最大元素个数
images, label_batch = tf.train.batch(
[image, label],
batch_size=batch_size,
num_threads=num_preprocess_threads,
capacity=min_queue_examples + 3 * batch_siz