CIFAR-10模型训练python版cifar10数据集

在之前的博客中已经对CIFAR-10做了整体的解析,但是目前从tensorflow/models/tree/master/tutorials/image/cifar10中下载下来,运行cifar10_train.py后训练的是binary(适用于C语言)版的数据集。

那么想训练CIFAR-10 python version数据集该怎么修改代码呢?

其实主要需要修改的部分是cifar10_input.py文件。因为python版本的数据集形式不相同,具体格式请上Alex官网的The CIFAR-10 dataset去了解。因为格式不同,导入数据集的代码部分对于数据集的解析也就不相同。python版如下:

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

这里就不罗嗦啦,直接向大家奉上整个代码:

from __future__ import print_function
import os
import tensorflow as tf
import pickle as pickle
import numpy as np


from PIL import Image

#encoding:utf-8
from scipy import ndimage

# Global constants describing the CIFAR-10 data set
# CIFAR10 image size of 32x32. will distort to 24x24
IMAGE_SIZE = 24

NUM_CLASSES = 10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000

# train_data_queue = None
# train_labels_queue = None
# train_f_names_queue = None

#读取数据集中的各个文件,按照类型生成相应格式,或列表 或矩阵
def read_cifar10_python_pickles(filenames):
    data = None
    labels = None
    f_names = None
    # Dict Keys from pickle files
    # ['data', 'labels', 'batch_label', 'filenames']
    """
    filenames = [os.path.join(data_dir, 'data_batch_%d' % i)
                 for i in xrange(1, 6)]"""
    for pickle_file in filenames:
        if not tf.gfile.Exists(pickle_file):
            raise ValueError('Failed to find file: ' + pickle_file)
        with open(pickle_file, 'rb') as p:
            # pickle.load(file,*,fix_imports=True, encoding="ASCII", errors="strict")
            #必填参数file必须以二进制可读模式打开,即“rb”,其他都为可选参数
            save = pickle.load(p,encoding='iso-8859-1')
            s_data = save['data']
            s_labels = np.array(save['labels'])
            s_f_names = np.array(save['filenames'])
            # 删除列表
            del save
            print('data set', s_data.shape, s_labels.shape)

            #numpy提供了numpy.append(arr, values, axis=None)函数。对于参数规定,
            # 要么一个数组和一个数值;要么两个数组,不能三个及以上数组直接append拼接。append函数返回的始终是一个一维数组。
            data = np.append(data, s_data, axis=0) if data is not None else s_data
            labels = np.append(labels, s_labels, axis=0) if labels is not None else s_labels
            f_names = np.append(f_names, s_f_names, axis=0) if f_names is not None else s_f_names
    print('Data set: ', data.shape, len(labels))
    return data, labels, f_names


def read_cifar10_python_pickle(filename):
    if not tf.gfile.Exists(filename):
        raise ValueError('Failed to find file: ' + filename)
    with open(filename, 'rb') as p:
        save = pickle.load(p,encoding='iso-8859-1')
        data = save['data']
        labels = np.array(save['labels'])
        f_names = np.array(save['filenames'])
        del save
        print('data set', data.shape, labels.shape)

    return data, labels, f_names


def read_cifar10_to_queue(filenames):

    data, labels, f_names = read_cifar10_python_pickles(filenames)

    # def input_producer(input_tensor,
    #                    element_shape=None,
    #                    num_epochs=None,
    #                    shuffle=True,
    #                    seed=None,
    #                    capacity=32,
    #                    shared_name=None,
    #                    summary_name=None,
    #                    name=None,
    #                    cancel_op=None):
    #这个地方是将数据按照类型作用进行生成队列
    data_queue = tf.train.input_producer(data, shuffle=False)
    labels_queue = tf.train.input_producer(labels, shuffle=False)
    f_names_queue = tf.train.input_producer(f_names, shuffle=False)

    return data_queue, labels_queue, f_names_queue


def read_cifar10_reader(data_q, labels_q):
    #dequeue,函数名,用于移除每个匹配元素的指定队列中的第一个函数,并执行被移除的函数。
    #将元素从队列中移出。如果在执行该操作时队列已空,
    #那么将会阻塞直到元素出列,返回出列的tensors的tuple
    return data_q.dequeue(), labels_q.dequeue()


def read_cifar10(filenames):
    class CIFAR10Record(object):
        pass
    result = CIFAR10Record()
    label_bytes = 1  # 2 for CIFAR-100
    result.height = 32
    result.width = 32
    result.depth = 3

    data_q, label_q, _ = read_cifar10_to_queue(filenames)
    data, label = read_cifar10_reader(data_q, label_q)
    print(data.get_shape(), data.dtype)
    print(label.get_shape(), label.dtype)
    result.label = tf.cast(label, tf.int32)#uint8转变成int32数据类型
    depth_major = tf.reshape(data, [result.depth, result.height, result.width])
    # Convert from [depth, height, width] to [height, width, depth].
    result.uint8image = tf.transpose(depth_major, [1, 2, 0])
    print(depth_major.get_shape(), depth_major.dtype)
    print(result.label.get_shape(), result.label.dtype)
    return result

# 构建一个排列后的一组图片和分类
def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, shuffle):
    """Construct a queued batch of images and labels.
  Args:
    image: 3-D Tensor of [height, width, 3] of type.float32.
    label: 1-D Tensor of type.int32
    min_queue_examples: int32, minimum number of samples to retain
      in the queue that provides of batches of examples.
    batch_size: Number of images per batch.
    shuffle: boolean indicating whether to use a shuffling queue.
  Returns:
    images: Images. 4D tensor of [batch_size, height, width, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
    # Create a queue that shuffles the examples, and then
    # read 'batch_size' images + labels from the example queue.
    num_preprocess_threads = 16
    if shuffle:
        images, label_batch = tf.train.shuffle_batch(
            [image, label],
            batch_size=batch_size,
            num_threads=num_preprocess_threads,
            capacity=min_queue_examples + 3 * batch_size,
            min_after_dequeue=min_queue_examples)
    else:
        # tf.train.batch(tensors, batch_size, num_threads=1, capacity=32,
        # enqueue_many=False, shapes=None, dynamic_pad=False,
        # allow_smaller_final_batch=False, shared_name=None, name=None)
        # 这里是用队列实现,已经默认使用enqueue_runner将enqueue_runner加入到Graph'senqueue_runner集合中
        # 其默认enqueue_many=False时,输入的tensor为一个样本【x,y,z】,输出为Tensor的一批样本
        # capacity:队列中允许最大元素个数
        images, label_batch = tf.train.batch(
            [image, label],
            batch_size=batch_size,
            num_threads=num_preprocess_threads,
            capacity=min_queue_examples + 3 * batch_siz
  • 2
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一摩尔自由

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值