CIFAR10 数据集由加拿大 Canadian Institute For Advanced Research 发布,它包含了飞
机、汽车、鸟、猫等共 10 大类物体的彩色图片,每个种类收集了 6000 张32 × 32大小图
片,共 6 万张图片。其中 5 万张作为训练数据集,1 万张作为测试数据集。每个种类样片
如图 所示。
关于数据集的解析可参考官网:(也可直接下载,数据集没多大,很快就下完)
https://www.cs.toronto.edu/~kriz/cifar.html
在 TensorFlow 中,同样地,不需要手动下载、解析和加载 CIFAR10 数据集,通过
datasets.cifar10.load_data()函数就可以直接加载切割好的训练集和测试集。例如:
# 在线下载,加载 CIFAR10 数据集
(x,y), (x_test, y_test) = datasets.cifar10.load_data()
# 删除 y 的一个维度,[b,1] => [b]
y = tf.squeeze(y, axis=1)
y_test = tf.squeeze(y_test, axis=1)
# 打印训练集和测试集的形状
print(x.shape, y.shape, x_test.shape, y_test.shape)
# 构建训练集对象,随机打乱,预处理,批量化
train_db = tf.data.Dataset.from_tensor_slices((x,y))
train_db = train_db.shuffle(1000).map(preprocess).batch(128)
# 构建测试集对象,预处理,批量化
test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = test_db.map(preprocess).batch(128)
# 从训练集中采样一个 Batch,并观察
sample = next(iter(train_db))
print('sample:', sample[0].shape, sample[1].shape,
tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))
此外,CIFAR-100数据集包含100小类,每小类包含600个图像,其中有500个训练图像和100个测试图像。100类被分组为20个大类。每个图像带有1个小类的“fine”标签和1个大类“coarse”标签。
考虑到网速慢的原因,这里直接将数据集下载,并存放在py文件当前目录。
实现源码如下:
# encoding: utf-8
from __future__ import print_function
class Cifar10DataReader():
import os
import random
import numpy as np
import pickle
def __init__(self, cifar_file, one_hot=False, file_number=1):
self.batch_index = 0 # 第i批次
self.file_number = file_number # 第i个文件数
self.cifar_file = cifar_file # 数据集所在dir
self.one_hot = one_hot
self.train_data = self.read_train_file() # 一个数据文件的训练集数据,得到的是一个1000大小的list,
self.test_data = self.read_test_data() # 得到1000个测试集数据
# 读取数据函数,返回dict
def unpickle(self, file):
with open(file, 'rb') as fo:
try:
dicts = self.pickle.load(fo, encoding='bytes')
except Exception as e:
print('load error', e)
return dicts
# 读取一个训练集文件,返回数据list
def read_train_file(self, files=''):
if files:
files = self.os.path.join(self.cifar_file, files)
else:
files = self.os.path.join(self.cifar_file, 'data_batch_%d' % self.file_number)
dict_train = self.unpickle(files)
train_data = list(zip(dict_train[b'data'], dict_train[b'labels'])) # 将数据和对应标签打包
self.np.random.shuffle(train_data)
print('成功读取到训练集数据:data_batch_%d' % self.file_number)
return train_data
# 读取测试集数据
def read_test_data(self):
files = self.os.path.join(self.cifar_file, 'test_batch')
dict_test = self.unpickle(files)
test_data = list(zip(dict_test[b'data'], dict_test[b'labels'])) # 将数据和对应标签打包
print('成功读取测试集数据')
return test_data
# 编码得到的数据,变成张量,并分别得到数据和标签
def encodedata(self, detum):
rdatas = list()
rlabels = list()
for d, l in detum:
rdatas.append(self.np.reshape(self.np.reshape(d, [3, 1024]).T, [32, 32, 3]))
if self.one_hot:
hot = self.np.zeros(10)
hot[int(l)] = 1
rlabels.append(hot)
else:
rlabels.append(l)
return rdatas, rlabels
# 得到batch_size大小的数据和标签
def nex_train_data(self, batch_size=100):
assert 1000 % batch_size == 0, 'erro batch_size can not divied!' # 判断批次大小是否能被整除
# 获得一个batch_size的数据
if self.batch_index < len(self.train_data) // batch_size: # 是否超出一个文件的数据量
detum = self.train_data[self.batch_index * batch_size:(self.batch_index + 1) * batch_size]
datas, labels = self.encodedata(detum)
self.batch_index += 1
else: # 超出了就加载下一个文件
self.batch_index = 0
if self.file_number == 5:
self.file_number = 1
else:
self.file_number += 1
self.read_train_file()
return self.nex_train_data(batch_size=batch_size)
return datas, labels
# 随机抽取batch_size大小的训练集
def next_test_data(self, batch_size=100):
detum = self.random.sample(self.test_data, batch_size) # 随机抽取
datas, labels = self.encodedata(detum)
return datas, labels
if __name__ == '__main__':
import matplotlib.pyplot as plt
Cifar10 = Cifar10DataReader(r'./cifar-10-batches-py2', one_hot=True)
d, l = Cifar10.nex_train_data()
print(len(d))
print(d)
plt.imshow(d[0])
plt.show()
飞机:
当然你也可以选择其他位置:下载下来的数据集放到~/.keras/datasets/ 目录下,然后将文件名改名为cifar-10-batches-py.tar.gz
sudo mv ~/Download/cifar-10-python.tar.gz ~/.keras/datasets/cifar-10-batches-py.tar.gz
解压一下:
tar xvfz cifar-10-batches-py.tar.gz
然后再使用
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
参考文献:
https://www.cnblogs.com/IAMzhuxiaofeng/p/9142582.html
https://www.cs.toronto.edu/~kriz/cifar.html
https://blog.csdn.net/WANG
https://blog.csdn.net/shadowl
https://blog.csdn.net/qq_26593695
https://blog.csdn.net/qq_36895331
https://blog.csdn.net/zhanzi1538/article/details/106878836/