CIFAR-10数据集
- 包含50000张训练图片
- 包含10000张测试图片
- 共分为10个类别
代码
import tensorflow as tf
import os
import pickle
import numpy as np
# Directory holding the CIFAR-10 batch files, relative to the working directory.
cifar_dir = './cifar-10-batches-py'
# Sanity check: list the directory contents. os.listdir raises
# FileNotFoundError here, at import time, if the directory is missing.
print(os.listdir(cifar_dir))
# Training batches data_batch_1 ... data_batch_5 (range(1, 6) -> 1..5).
train_filenames = [
    os.path.join(cifar_dir, 'data_batch_{}'.format(batch_no))
    for batch_no in range(1, 6)
]
# The single test batch file.
test_filenames = [os.path.join(cifar_dir, 'test_batch')]
def load_data(filename):
    """Load one pickled CIFAR-10 batch file.

    Args:
        filename: path to a batch file, e.g. one entry of ``train_filenames``.

    Returns:
        A ``(data, labels)`` tuple holding the values stored under the
        ``b'data'`` and ``b'labels'`` keys of the unpickled dict.
    """
    # encoding='bytes' makes pickle decode legacy (Python 2) str objects as
    # bytes, which is why the dict keys below are bytes literals.
    # NOTE(review): pickle can execute arbitrary code while loading; only read
    # batch files obtained from a trusted source.
    with open(filename, 'rb') as fh:
        batch = pickle.load(fh, encoding='bytes')
    images = batch[b'data']
    targets = batch[b'labels']
    return images, targets
class CifarData:
def __init__(self,filenames,need_shuffle):
all_data = []
all_labels = []
for filename in filenames:
data, labels = load_data(filename)
all_data.append(data)
all_labels.append(labels)
self._data = np.vstack(all_data) / 127.5 - 1
self._labels = np.hstack(all_labels)
self._num_examples = self._data.shape[0]
self._index = 0
self._need_shuffle = need_shuffle
if self._need_shuffle:
self.shuffle_data()
def shuffle_data(self):
o = np.random.p