作为图像分类处理的经典数据集,在刚入门TensorFlow后,难免就想在cifar-10上展现一下自己。在此就想分享一下,自己在该数据集上遇到的问题。
首先,简单的介绍一下这个数据集。
一. CIFAR-10
① CIFAR-10数据集包含60000个32*32的彩色图像,共有10类。有50000个训练图像和10000个测试图像。
数据集分为5个训练块和1个测试块,每个块有10000个图像。测试块包含从每类随机选择的1000个图像。训练块以随机的顺序包含这些图像,但一些训练块可能比其它类包含更多的图像。训练块每类包含5000个图像。
②data——1个10000*3072大小的uint8s数组。数组的每行存储1张32*32的图像。第1个1024包含红色通道值,下1个包含绿色,最后的1024包含蓝色。图像存储以行顺序为主,所以数组的前32列为图像第1行的红色通道值。
labels——1个10000数的范围为0~9的列表。索引i的数值表示数组data中第i个图像的标签。
③数据集中包含另外1个叫batches.meta的文件。它也包含1个Python字典对象。有如下列元素:
label_names——1个10元素的列表,给labels中的数值标签以有意义的名称。例如,label_names[0] == “airplane”, label_names[1] == “automobile”等。
下图显示的是数据集的类,以及每一类中随机挑选的10张图片:
二、CIFAR-10数据集解析
官方给出了多个CIFAR-10数据集的版本,以下是链接:
Version | Size | md5sum |
---|---|---|
CIFAR-10 python version | 163 MB | c58f30108f718f92721af3b95e74349a |
CIFAR-10 Matlab version | 175 MB | 70270af85842c9e89bb428ec9976c926 |
CIFAR-10 binary version (suitable for C programs) | 162 MB | c32a1d4ab5d03f1284b67883e8d87530 |
此处我们下载python版本。
下载完成后,解压,得到如下目录结构的文件夹:
其中:
名称 | 作用 |
---|---|
batches.meta | 程序中不需要使用该文件 |
data_batch_1 | 训练集的第一个batch,含有10000张图片 |
data_batch_2 | 训练集的第二个batch,含有10000张图片 |
data_batch_3 | 训练集的第三个batch,含有10000张图片 |
data_batch_4 | 训练集的第四个batch,含有10000张图片 |
data_batch_5 | 训练集的第五个batch,含有10000张图片 |
readme.html | 网页文件,程序中不需要使用该文件 |
test_batch | 测试集的batch,含有10000张图片 |
上述文件结构中,每一个batch文件包含一个python的字典(dict)结构,结构如下:
名称 | 作用 |
---|---|
b’data’ | 是一个10000x3072的array,每一行的元素组成了一个32x32的3通道图片,共10000张 |
b’labels’ | 一个长度为10000的list,对应包含data中每一张图片的label |
b’batch_label’ | 这一份batch的名称 |
b’filenames’ | 一个长度为10000的list,对应包含data中每一张图片的名称 |
真正重要的两个关键字是data和labels,剩下的两个并不是十分重要。
接下来就是我们怎么获取数据了。(这里,我们是先将Python版的CIFAR-10数据先下载下来了,不要问我怎么下,上面有链接哦)
import numpy as np import pickle import tensorflow as tf import platform import random import cv2 class Get_cifar: def __init__(self): self.load_cifar10('./cifar-10-python/cifar-10-batches-py') #数据存放的路径 self._split_train_valid(valid_rate=0.9) #将部分数据作为验证集 self.n_train = self.train_images.shape[0] #训练集的数量 self.n_valid = self.valid_images.shape[0] #验证集的数量 self.n_test = self.test_images.shape[0] #测试集的数量 def load_cifar10(self, directory): # 读取训练集 print("开始读取数据集") images, labels = [], [] for filename in ['%s/data_batch_%d' % (directory, j) for j in range(1, 6)]: with open(filename, 'rb') as fo: if 'Windows' in platform.platform(): cifar10 = pickle.load(fo, encoding='bytes') #pickle.load 返回的是一个字典类型 elif 'Linux' in platform.platform(): cifar10 = pickle.load(fo) for i in range(len(cifar10[b"labels"])): image = np.reshape(cifar10[b"data"][i], (3, 32, 32)) image = np.transpose(image, (1, 2, 0)) #变成32*32*3的 image = image.astype(float) images.append(image) labels += cifar10[b"labels"] images = np.array(images, dtype='float') labels = np.array(labels, dtype='int') labels = tf.one_hot(indices=labels, depth=10) self.train_images, self.train_labels = images, labels print("train_images.shape---------------:", images.shape) print("train_labels.shape---------------:", labels.shape) # 读取测试集 images, labels = [], [] for filename in ['%s/test_batch' % (directory)]: with open(filename, 'rb') as fo: if 'Windows' in platform.platform(): cifar10 = pickle.load(fo, encoding='bytes') elif 'Linux' in platform.platform(): cifar10 = pickle.load(fo) for i in range(len(cifar10[b"labels"])): image = np.reshape(cifar10[b"data"][i], (3, 32, 32)) image = np.transpose(image, (1, 2, 0)) image = image.astype(float) images.append(image) labels += cifar10[b"labels"] images = np.array(images, dtype='float') labels = np.array(labels, dtype='int') labels = tf.one_hot(indices=labels, depth=10) self.test_images, self.test_labels = images, labels print("test_images.shape---------------:", images.shape) print("test_labels.shape---------------:", labels.shape) def _split_train_valid(self, valid_rate=0.9): images, labels = self.train_images, self.train_labels thresh = int(images.shape[0] * valid_rate) self.train_images, self.train_labels = images[0:thresh, :, :, :], labels[0:thresh] self.valid_images, self.valid_labels = images[thresh:, :, :, :], labels[thresh:] def data_augmentation(self, images, mode='train', flip=False, crop=False, crop_shape=(24, 24, 3), whiten=False, noise=False, noise_mean=0, noise_std=0.01): # 图像切割 if crop: if mode == 'train': images = self._image_crop(images, shape=crop_shape) elif mode == 'test': images = self._image_crop_test(images, shape=crop_shape) # 图像翻转 if flip: images = self._image_flip(images) # 图像白化 if whiten: images = self._image_whitening(images) # 图像噪声 if noise: images = self._image_noise(images, mean=noise_mean, std=noise_std) return images def _image_crop(self, images, shape): # 图像切割 new_images = [] for i in range(images.shape[0]): old_image = images[i, :, :, :] left = np.random.randint(old_image.shape[0] - shape[0] + 1) top = np.random.randint(old_image.shape[1] - shape[1] + 1) new_image = old_image[left: left + shape[0], top: top + shape[1], :] new_images.append(new_image) return np.array(new_images) def _image_crop_test(self, images, shape): # 图像切割 new_images = [] for i in range(images.shape[0]): old_image = images[i, :, :, :] left = int((old_image.shape[0] - shape[0]) / 2) top = int((old_image.shape[1] - shape[1]) / 2) new_image = old_image[left: left + shape[0], top: top + shape[1], :] new_images.append(new_image) return np.array(new_images) def _image_flip(self, images): # 图像翻转 for i in range(images.shape[0]): old_image = images[i, :, :, :] if np.random.random() < 0.5: new_image = cv2.flip(old_image, 1) else: new_image = old_image images[i, :, :, :] = new_image return images def _image_whitening(self, images): # 图像白化 for i in range(images.shape[0]): old_image = images[i, :, :, :] new_image = (old_image - np.mean(old_image)) / np.std(old_image) images[i, :, :, :] = new_image return images def _image_noise(self, images, mean=0, std=0.01): # 图像噪声 for i in range(images.shape[0]): old_image = images[i, :, :, :] new_image = old_image for i in range(images.shape[0]): for j in range(images.shape[1]): for k in range(images.shape[2]): new_image[i, j, k] += random.gauss(mean, std) images[i, :, :, :] = new_image return images
这样,我们就可以得到 cifar-10 训练集、测试集对应的图片及它的标签了,标签进行了one-hot编码。在这里呢,我还增加了一个验证集,他将用于神经网络训练时准确率的验证。
接下来我会将该数据集应用于各种典型的神经网络中,尽情期待