keras中提供的cifar10数据集可能因为网速等问题无法直接下载读取,可以进入官网下载到本地,网址:
http://www.cs.toronto.edu/~kriz/cifar.html,
这里我们下载python版本的。
将下载的tar.gz形式的文件解压,放到想要存放数据文件的文件夹中,这里我的文件存放位置为"/Users/shiruihuo/Documents/study/深度学习/data/cifar10/cifar-10-batches-py"。使用以下脚本可以正确的转换train和test的数据及标签。
# -*- coding: utf-8 -*-
import pickle as p
import numpy as np
import os
def load_CIFAR_batch(filename):
""" 载入cifar数据集的一个batch """
with open(filename, 'rb') as f:
datadict = p.load(f, encoding='bytes')
X = datadict[b'data']
Y = datadict[b'labels']
X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
Y = np.array(Y)
return X, Y
def load_CIFAR10(ROOT):
""" 载入cifar全部数据 """