CIFAR10数据集介绍
CIFAR10数据集包括10类图像,每张图像的大小为32*32,包含如上图的十个类别的对象。每个类都包含6000张图片,总共有60000张图片,数据集平衡。其中,训练组图像包含50000张图片,测试集包含10000张图像。
数据集的下载
数据集地址:http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
也可以使用pytorch中的方法来获取数据集:
trainset = torchvision.datasets.CIFAR10(root='存储路径',
train=True,
download=True,
transform = transform,
)
testset = torchvision.datasets.CIFAR10(root='存储路径',
train=False,
download=True,
transform = transform,
)
下载后的数据集如下:
包含五个训练batch和一个测试batch,每个batch包含一万张图片。在做深度学习训练的时候直接从batch中读取数据就好,也可以转换为PNG或者JPG图片格式来再进行读取和查看图像数据。
读取代码如下:
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/5 下午1:02
import cv2
import numpy as np
from six.moves import cPickle as pickle
#解压缩二进制文件
def unpack(file):
fo = open(file, "rb")
dict = pickle.load(fo,encoding='iso-8859-1')
fo.close()
return dict
## unpack trainset
for i in range(1,6):
data_name = "训练batch路径" + str(i)
Xtr = unpack(data_name)
print(data_name + 'is loading....')
for j in range(10000):
img = np.reshape(Xtr['data'][j],(3,32,32))
img = img.transpose(1,2,0)
img_name = 'train/' + str(Xtr['labels'][j]) + '_' + str(j+ (i-1)*10000) + '.jpg'
cv2.imwrite(img_name,img)
print(data_name + 'is loaded....')
testXtr = unpack('测试batch路径')
for i in range(0,10000):
img = np.reshape(testXtr['data'][i],(3,32,32))
img = img.transpose(1,2,0)
img_name = 'test/' + str(testXtr['labels'][i]) + '_' + str(i) + '.jpg'
cv2.imwrite(img_name, img)
在python3中解压二进制文件要带上这一句:
dict = pickle.load(fo,encoding='iso-8859-1')
否则会出现编码错误。