CIFAR-10 数据集的读取与使用方法
下载 CIFAR-10 数据集
官网下载链接：
https://www.cs.toronto.edu/~kriz/cifar.html
这里选择的是 Python 版本的 CIFAR-10 数据集（页面上的 “CIFAR-10 python version”）。
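在 PyCharm 中使用 CIFAR-10 数据集
将下载好的压缩包解压到 PyCharm 当前项目的目录下。解压后会得到文件夹 cifar-10-batches-py，其中包含训练集文件 data_batch_1 ~ data_batch_5 和测试集文件 test_batch。下面的代码使用相对路径 'cifar-10-batches-py' 访问该文件夹，相对路径按程序运行时的工作目录解析，因此运行时的工作目录应为该文件夹所在的目录。
读取数据集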
import pickle
import os
import numpy as np
def load_CIFAR_batch(filename):
    """Load one pickled CIFAR-10 batch file (python version of the dataset).

    Args:
        filename: Path to a batch file such as ``data_batch_1`` or
            ``test_batch``.

    Returns:
        A ``(data, label)`` tuple. ``data`` is a float64 array of shape
        ``(N, 32, 32, 3)`` (height, width, channel), where N is the number
        of images stored in the file. ``label`` is a 1-D NumPy array built
        from the file's ``labels`` list.

    Note:
        ``pickle.load`` can execute arbitrary code while unpickling; only
        load batch files obtained from a trusted source such as the
        official download.
    """
    with open(filename, 'rb') as f:
        # encoding='latin1' allows Python 3 to unpickle files that were
        # written by Python 2.
        data_dict = pickle.load(f, encoding='latin1')
    data = data_dict['data']
    label = data_dict['labels']
    # Each row holds one image as 3 * 32 * 32 values, channel-major.
    # reshape -> (N, channel, height, width); transpose moves the channel
    # axis last; astype("float") copies into float64.
    # -1 lets NumPy infer N instead of hard-coding 10000 images per batch.
    data = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    label = np.array(label)
    return data, label
def load_CIFAR10(ROOT):
    """Load the whole CIFAR-10 dataset from an extracted python-version folder.

    Args:
        ROOT: Directory containing ``data_batch_1`` ... ``data_batch_5``
            and ``test_batch``.

    Returns:
        ``(data_train, label_train, data_test, label_test)``. The training
        arrays are the five training batches concatenated along the first
        axis in order 1..5. The test arrays come from ``test_batch``. Every
        batch is read with ``load_CIFAR_batch``.
    """
    # Read training batches 1..5 in order; each item is a (data, label) pair.
    train_batches = [
        load_CIFAR_batch(os.path.join(ROOT, 'data_batch_{0}'.format(idx)))
        for idx in range(1, 6)
    ]
    train_images, train_labels = zip(*train_batches)
    data_train = np.concatenate(train_images)
    label_train = np.concatenate(train_labels)
    # Release the per-batch references before reading the test batch so the
    # un-concatenated copies can be freed.
    del train_batches, train_images, train_labels

    test_path = os.path.join(ROOT, 'test_batch')
    data_test, label_test = load_CIFAR_batch(test_path)
    return data_train, label_train, data_test, label_test
if __name__ == '__main__':
    # Load both splits from the folder addressed by this relative path
    # (resolved against the current working directory); the four arrays
    # stay bound at module level.
    (data_train, label_train,
     data_test, label_test) = load_CIFAR10('cifar-10-batches-py')
对数据进行简单处理（例如归一化）后，就可以为我们搭建的模型提供训练集和测试集数据。读取得到的训练数据形状为 (50000, 32, 32, 3)，训练标签形状为 (50000,)；测试数据形状为 (10000, 32, 32, 3)，测试标签形状为 (10000,)。