【深度学习初探】Day11 - Cirfa10数据集
文章目录
Cifar-10数据集是一个比较出名的数据集。它一共有10类图片,每一类有6000张图片,加起来就是60000张图片,每张图片的尺寸是32x32,图片是彩色图。今天我们来了解一下这个数据集,然后对其进行可视化操作。
一、Cirfa10的数据构成
Cirfa10数据集被分为5个训练批次和1个测试批次,每一批10000张图片。测试批次包含10000张图片,是由每一类图片随机抽取出1000张组成的集合。剩下的50000张图片每一类的图片数量都是5000张,训练批次是由剩下的50000张图片打乱顺序,然后随机分成5份,所以可能某个训练批次中10个种类的图片数量不是对等的,会出现一个类的图片数量比另一类多的情况。
Cirfa10数据集的文件构成如下,test_batch是测试批次,data_batch_1(2、3、4、5)是训练批次。
Cirfa10数据集包含10个种类,10个种类如下所示,每个种类随机选取了10张图:
二、Cirfa10数据集的导入
2.1 unpickle函数
Cirfa10数据集的导入,我们可以用到pickle包,将数据集通过文件打开的形式,导入数据,其中得到的数据存储格式是字典(dict)。我们编写一个 unpickle( ) 函数,用于导入数据集,如下所示:
def unpickle(file):
import pickle
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
通过unpickle( )函数将每一批次的训练文件导入。
2.2 Cirfa10数据格式
Cirfa10的每个data_batch中的数据包含4个字典键,分别是:data、labels、batch_label和filenames,如果需要查看,直接使用print(dict)就可以打印出来。各键值对的说明如下:
这里要注意,我们访问字典键时,要在索引[ ]的字符串前加 b ,因为我们导入的数据是 bytes 编码。
data_batch = unpickle("data_batch_1") # 打开data_batch_1
cifar_data = data_batch[b'data'] # 字符串前加b表示是bytes格式
cifar_label = data_batch[b'labels']
我们观察一下 data 字典键,它保存的是图片的像素值,那是一个 10000 × 3072 的数组,对于这个数组的理解是,10000 行表示10000 张图片,每一张图片有 32 × 32 × 3 个像素值,乘 3 的原因是:每一张图片都是RGB格式的彩色图片(每一张彩色图片都是通过拆分成红色分量矩阵R、绿色分量矩阵G、蓝色分量矩阵B来保存),把3个颜色分量矩阵分别拉伸成1行,然后再拼在一起,就是3072个数值。
我们观察一下 labels 字典键,它保存的是这些图片的标签,这个数据集有10种标签,我们用 0-9 来表示。
那顾名思义,label_names 就是存储标签名字的,如下数字和标签名一一对应。
'airplane'=0
'automobile'=1
'brid'=2
'cat'=3
'deer'=4
'dog'=5
'frog'=6
'horse'=7
'ship'=8
'truck'=9
三、Cirfa-10数据可视化代码
成功导入数据集后,我们可以利用cv2的相关函数,将这些存储在data中的像素值,重新合成RGB图像,并绘制出来,保存在本地文件,从而实现了数据集的可视化。首先我们利用写好的unpickle函数,导入数据集:
import numpy as np
import cv2
def unpickle(file): #打开cirfa10文件的其中一个batch
import pickle
with open("./datasets/5f6b1577787e9d5bb70800a4-momodel/cifar-10-batches-py/"+ file,'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
data_batch = unpickle("data_batch_1") # 打开目录里的data_batch_1
cifar_data = data_batch[b'data'] #注意字典索引字符串前面要加b,因为读入的是bytes格式
cifar_label = data_batch[b'labels']
cifar_data = np.array(cifar_data) # 把字典值转化为Numpy的array格式,便于后续操作
print(cifar_data.shape)
cifar_label = np.array(cifar_label)
print(cifar_label.shape)
输出:
(10000, 3072)
(10000,)
通过导入得到的NumPy数组,cifar_data 是一个二维数组,有10000行,3072列,意思是10000张图片,每张图片的三通道像素 32 × 32 × 3 被平铺;cifar_label 数组,是一个10000个元素的元组,存储了这一万张图片的标签,数据类型是整型数字。
我们将这些图片,重新绘制,利用reshape函数变成原来图片的尺寸,然后还原成彩色图像,使用imwrite函数将图片转储。
label_names = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
def imwrite_images(k): # k的值可以取1-10000范围内的值
for i in range(k):
image = cifar_data[i] # 每一行是一张图片
image = image.reshape(-1, 1024) # 将数据分为1024列,因为32 × 32是1024,要把三通道合成一张图片,-1表示unspecified value,不知道有几行,反正是分成1024列,如此一来,一张图片应该占据三行
r = image[0,:].reshape(32,32) # 经过1024列重排列后,一张图片应该占三行,红色分量占第一行,把它变回32 × 32的原图大小
g = image[1,:].reshape(32,32) # 绿色分量
b = image[2,:].reshape(32,32) # 蓝色分量
img = np.zeros((32,32,3)) # 初始化,用零填充一个三维数组,前两个维度是32 × 32,第三个维度是3
#RGB还原成彩色图像
img[:,:,0] = r # 把第0维替换成32 × 32的红色分量图
img[:,:,1] = g # 把第1维替换成32 × 32的绿色分量图
img[:,:,2] = b # 把第2维替换成32 × 32的蓝色分量图
cv2.imwrite("D:/[]Python Work Space/PytorchStudy/datasets/5f6b1577787e9d5bb70800a4-momodel/images" + "NO."+str(i) + "class" + str(cifar_label[i]) + str(label_names[cifar_label[i]])+".jpg",img)
print("%d张图片保存完毕"%k)
输出:
100张图片保存完毕
至此,我们完成了对这个数据集的图片可视化。下次,我们试图使用一个基本的网络来做一下图像分类。