批量读取 CIFAR10 数据集
本文重点介绍通过拆分原始数据集来加载和训练神经网络模型。 当整个数据集对于本地 RAM 来说太大并且必须在使用“model.fit”训练模型之前拆分
背景 Background
我们通常这样加载 CIFAR10 图像数据集
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
(train_data, train_labels), (test_data, test_labels) = datasets.cifar10.load_data()
load_data()
函数会自动下载数据集并将其存储在 ~/.keras/datasets/cifar-10-batches-py/
中,这就是您可以在文件夹中找到的内容。 查看 load_data()
的源代码,它会加载所有批次并让我们将数据很好地返回到 train_data, train_labels), (test_data, test_labels)
,但是如果您想一一阅读它们怎么办? 你会怎么做,这就是我们今天要探讨的。
Pickle
我们用python的pickle
模块来加载数据。首先定义这个函数
def load_pickle(filename):
""" load correct version of pickle """
version = platform.python_version_tuple()
if version[0] == '2':
return pickle.load(filename)
elif version[0] == '3':
return pickle.load(filename, encoding='latin1')
raise ValueError("invalid python version: {}".format(version))