Cifar-10 的介绍可去官网阅读,也可参照我之前整理的《笔记:CIFAR-01 和 CIFAR-100 数据集内容和格式详解》
1. 下载 Cifar-10 数据
本文下载了 Cifar-10 的 Python 语言版本,解压后放在文件夹:...\cifar-10-python\cifar-10-batches-py
中。其中包含如下文件:
2. 读取 Cifar-10 数据
按照官网说明,这些数据可以用如下 Python 代码读取:
def unpickle(file):
import pickle
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
dict = unpickle('/home/yeping/cifar-10/cifar-10-batches-py/data_batch_1')
print(dict)
------------- run -------------
{
b'batch_label': b'training batch 1 of 5',
b'labels': [6, 9, 9, 4, 1,..., 5],
b'data': array(
[[ 59, 43, 50, ..., 140, 84, 72],
[154, 126, 105, ..., 139, 142, 144],
[255, 253, 253, ..., 83, 83, 84],
...,
[ 62, 61, 60, ..., 130, 130, 131]], dtype=uint8),
b'filenames': [
b'leptodactylus_pentadactylus_s_000004.png',
b'camion_s_000148.png', b'tipper_truck_s_001250.png',
b'american_elk_s_001521.png', b'station_wagon_s_000293.png',
... ,
b'cur_s_000170.png']
}
上面代码读取并打印了 data_batch_1
文件。其实文件的二进制格式不必细究,我们只需要关心遵循官网说明用 Python 读出的内容即可。数据以字典的形式读出,内容包括如下几个部分
字段 | 字段名称 | 数据 | 说明 |
---|---|---|---|
batch_label | 批文件标签 | ‘training batch 1 of 5’ | 当前文件的标题或说明 |
label | 标签 | [6, 9, 9, 4, 1,…, 5] | 图像的分类标签,shape = (10000,) |
data | 数据 | [[ 59, 43, 50, …, 140, 84, 72],… | 图像数据, shape = (10000, 3072) |
filename | 文件名 | … | 图像的文件名 |
数据的 shape 可以用下面的代码读取:
print(np.array(dict[b'labels']).shape) # 转成 np.array
print(dict[b'data'].shape) # 这个本来就是 np.array
------------- run -------------
(10000,)
(10000, 3072)
附:官网关于 Cifar-10 的说明
- 数据——一个10000x3072 numpy的uint8s数组。数组的每一行存储一个32x32彩色图像。前1024个条目包含红色通道值,后1024个条目包含绿色通道值,最后1024个条目包含蓝色通道值。图像以行主顺序存储,因此数组的前32个条目是图像第一行的红色通道值。
- 标签——10000个数字的列表,范围为0-9。索引i处的数字表示数组数据中第i个图像的标签。
3. 显示图形
数组中,每 3072 个字节代表一幅 32x32 分辨率的彩色图像。前1024个条目包含红色通道值,后1024个条目包含绿色通道值,最后1024个条目包含蓝色通道值。图像以行主顺序存储,因此数组的前32个条目是图像第一行的红色通道值。
下面的代码首先把 3072 个字节 reshape 成 3x32x32 格式。但是,plt.imshow 方法现实的图像格式为 32x32x3。因此,用 transpose 把通道位置置换一下,就能正常显示了。
img = dict[b'data'][0]
img = img.reshape(3,32,32)
#print(img)
import matplotlib.pyplot as plt
plt.imshow(np.transpose(img, (1, 2, 0)))
其中,np.transpose(img, (1, 2, 0))
的通道转模式如下:
通道编号 | 转换前的原数据通道 | 转换后的原数据通道位置分布 |
---|---|---|
0 | 0 | 1 |
1 | 1 | 2 |
2 | 2 | 0 |
也就是说,原来的0通道要转换到2通道位置,原来的1通道转换到0通道位置,原来的2通道转移到1通道位置。
4. 总结
代码汇总:
# Cifar-10 官方提供的数据读取代码
def unpickle(file):
import pickle
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
# 读取数据,并显示数据结构
dict = unpickle('/home/yeping/cifar-10/cifar-10-batches-py/data_batch_1')
print(np.array(dict[b'labels']).shape)#adwsdfs
print(dict[b'data'].shape)
# 把数据结构调整成图像(Cifar-10 数据结构与神经网络输入层的结构一致)
img=dict[b'data'][0]
img=img.reshape(3,32,32)
# 显示图像,需要把数据结构调整成正常的图像格式
import matplotlib.pyplot as plt
plt.imshow(np.transpose(img, (1, 2, 0)))
输出结果: