[dataset]MNIST,CIFAR-10

本篇记录深度学习相关的各种数据集的下载


一、MNIST手写体数据集

①可以不用下载,在keras.datasets包中直接调用

from keras.datasets import mnist

class MNISTDataSet:

    def __init__(self, use_split=False, split_rate=0.15, random_state=42, ds_path=None):
        """Download/load the MNIST data set and expose train/valid/test arrays.

        :param use_split: when True, carve a validation set out of the training data.
        :param split_rate: fraction of the training data held out for validation, default 0.15.
        :param random_state: random seed for the train/validation shuffle, default 42.
        :param ds_path: directory where the MNIST files are (or will be) stored; required.
        :raises AssertionError: if ds_path is not provided.
        """
        self.use_split = use_split
        self.split_rate = split_rate
        self.random_state = random_state

        self.ds_path = ds_path

        # Explicit check instead of `assert`: asserts are stripped when Python
        # runs with -O, which would silently skip this required-path validation.
        if not self.ds_path:
            raise AssertionError("[-] MNIST DataSet Path is required!")

        # Deferred import: tensorflow is heavy and only needed at load time.
        from tensorflow.examples.tutorials.mnist import input_data
        self.data = input_data.read_data_sets(self.ds_path, one_hot=True)  # download MNIST

        # training data
        self.train_data = self.data.train

        self.train_images = self.train_data.images
        self.train_labels = self.train_data.labels
        self.valid_images = None
        self.valid_labels = None

        # test data
        self.test_data = self.data.test

        self.test_images = self.test_data.images
        self.test_labels = self.test_data.labels

        # split the training data into train / validation subsets
        # NOTE(review): `train_test_split` is presumably sklearn's, imported at
        # file level outside this view — confirm.
        if self.use_split:
            self.train_images, self.valid_images, self.train_labels, self.valid_labels = \
                train_test_split(self.train_images, self.train_labels,
                                 test_size=self.split_rate,
                                 random_state=self.random_state)

 

二、CIFAR-10

基本信息

CIFAR-10 是一个包含60000张图片的数据集。其中每张照片为32*32的彩色照片,每个像素点包括RGB三个数值,数值范围 0 ~ 255。

所有照片分属10个不同的类别,分别是 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'

 

其中,5万张作为训练集,1万张作为测试集。
整个数据集被划分为5个训练批次和1个测试批次,每一批都是1万张。

  • 测试集是从每一种分类中随机抽取出来1000张组成。
  • 训练集从10个分类中各自随机抽取5000张,一共5万张

下载数据集

The CIFAR-10 dataset & The CIFAR-100 dataset

data_batch_1 ~ data_batch_5 是划分好的训练数据,每个文件里包含10000张图片;test_batch 是测试集数据,也包含10000张图片。它们的结构是一样的,下面就用 data_batch_1 作为例子进行说明。

代码如下:

class CiFarDataSet:

    @staticmethod
    def unpickle(file):
        """Load one CIFAR batch file.

        :param file: path to a pickled CIFAR batch.
        :return: dict keyed by byte strings (b'data', b'labels', ...).
        """
        import pickle

        # WARN: Only for python3, NOT FOR python2
        assert sys.version_info >= (3, 0)

        with open(file, 'rb') as f:
            return pickle.load(f, encoding='bytes')

    def __init__(self, height=32, width=32, channel=3,
                 use_split=False, split_rate=0.2, random_state=42, ds_name="cifar-10", ds_path=None):

        """
        # General Settings
        :param height: input image height, default 32
        :param width: input image width, default 32
        :param channel: input image channel, default 3 (RGB)
        - in case of CIFAR, image size is 32 x 32 x 3 (HWC).

        # Pre-Processing Option
        :param use_split: split training data into train/valid, default False
        :param split_rate: validation split rate (of the training data), default 0.2
        :param random_state: random seed for shuffling, default 42

        # DataSet Option
        :param ds_name: DataSet's name, "cifar-10" or "cifar-100", default cifar-10
        :param ds_path: DataSet's path, required
        """

        self.height = height
        self.width = width
        self.channel = channel

        self.use_split = use_split
        self.split_rate = split_rate
        self.random_state = random_state

        self.ds_name = ds_name
        self.ds_path = ds_path  # DataSet path
        self.n_classes = 10     # DataSet the number of classes, default 10

        self.train_images = None
        self.valid_images = None
        self.test_images = None

        self.train_labels = None
        self.valid_labels = None
        self.test_labels = None

        # Explicit check instead of `assert`: asserts are stripped under -O,
        # which would silently skip this required-path validation.
        if not self.ds_path:
            raise AssertionError("[-] CIFAR10/100 DataSets' Path is required!")

        if self.ds_name == "cifar-10":
            self.cifar_10()   # loading Cifar-10
        elif self.ds_name == "cifar-100":
            self.cifar_100()  # loading Cifar-100
        else:
            raise NotImplementedError("[-] Only 'cifar-10' or 'cifar-100'")

    def _reshape_images(self, data):
        """Convert raw CIFAR rows of shape (N, H*W*C) into NHWC images.

        Each raw row stores the channels as contiguous H*W planes (all of
        channel 0, then channel 1, ...).  A Fortran-order reshape interprets
        each row as (col, row, channel); swapping axes 1 and 2 then yields
        the conventional (row, col, channel) layout.
        """
        return np.swapaxes(data.reshape([-1,
                                         self.height,
                                         self.width,
                                         self.channel], order='F'), 1, 2)

    def _finalize(self, train_images, train_labels, test_images, test_labels):
        """Store the splits on self, carving out a validation set if requested,
        and one-hot encode every label array (using self.n_classes)."""
        if self.use_split:
            train_images, valid_images, train_labels, valid_labels = \
                train_test_split(train_images, train_labels,
                                 test_size=self.split_rate,
                                 random_state=self.random_state)

            self.valid_images = valid_images
            self.valid_labels = one_hot(valid_labels, self.n_classes)

        self.train_images = train_images
        self.test_images = test_images

        self.train_labels = one_hot(train_labels, self.n_classes)
        self.test_labels = one_hot(test_labels, self.n_classes)

    def cifar_10(self):
        """Load the five CIFAR-10 training batches and the test batch."""
        self.n_classes = 10  # labels

        # data_batch_1 .. data_batch_5 hold the training data
        train_batches = [self.unpickle("{0}/data_batch_{1}".format(self.ds_path, i))
                         for i in range(1, 6)]

        # training data & label
        train_data = np.concatenate([b[b'data'] for b in train_batches], axis=0)
        train_labels = np.concatenate([b[b'labels'] for b in train_batches], axis=0)

        # Image size : 32x32x3
        train_images = self._reshape_images(train_data)

        # test data & label
        test_batch = self.unpickle("{0}/test_batch".format(self.ds_path))

        test_labels = np.array(test_batch[b'labels'])
        test_images = self._reshape_images(test_batch[b'data'])

        self._finalize(train_images, train_labels, test_images, test_labels)

    def cifar_100(self):
        """Load the CIFAR-100 train/test files, using the fine (100-way) labels."""
        self.n_classes = 100  # labels

        # training data & label
        train_batch = self.unpickle("{0}/train".format(self.ds_path))

        train_labels = np.asarray(train_batch[b'fine_labels'])
        train_images = self._reshape_images(np.asarray(train_batch[b'data']))

        # test data & label
        test_batch = self.unpickle("{0}/test".format(self.ds_path))

        test_labels = np.asarray(test_batch[b'fine_labels'])
        test_images = self._reshape_images(np.asarray(test_batch[b'data']))

        self._finalize(train_images, train_labels, test_images, test_labels)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值