# IPython session output (not code):
#   Exception reporting mode: Verbose
#   Automatic pdb calling has been turned ON
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
import tensorflow as tf
# 1. Define the CIFAR-10 loader class
class CifarLoader(object):
    """Load CIFAR-10 batch files and hand them out as sequential mini-batches.

    Attributes:
        images: ``None`` until :meth:`load` is called; afterwards a float
            array of shape ``(n, 32, 32, 3)``.
        labels: ``None`` until :meth:`load` is called; afterwards a one-hot
            array with 10 columns.
    """

    def __init__(self, source_files):
        """Remember which batch files to read; nothing is loaded yet.

        Args:
            source_files: Iterable of batch file names, each passed to
                ``unpickle`` by :meth:`load`.
        """
        self._source = source_files
        self._i = 0  # cursor: index of the first sample of the next batch
        self.images = None
        self.labels = None

    def load(self):
        """Read every source file and populate ``images`` and ``labels``.

        Returns:
            ``self``, so construction and loading can be chained.
        """
        data = [unpickle(f) for f in self._source]
        images = np.vstack([d[b"data"] for d in data])
        n = len(images)
        # Rows are reshaped to (n, 3, 32, 32) and the size-3 axis is moved to
        # the end, giving (n, 32, 32, 3).  Dividing by 255 presumably maps
        # 8-bit pixel values into [0, 1] -- TODO confirm the stored dtype.
        images = images.reshape(n, 3, 32, 32).transpose(0, 2, 3, 1)
        self.images = images.astype(float) / 255
        self.labels = one_hot(np.hstack([d[b"labels"] for d in data]), 10)
        return self

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` samples and advance the cursor.

        The cursor advances modulo the dataset length.  When fewer than
        ``batch_size`` samples remain before the end, the returned batch is
        shorter than requested (the slice is not wrapped around).

        Args:
            batch_size: Number of samples requested.

        Returns:
            Tuple ``(x, y)`` of image and label slices.
        """
        start, stop = self._i, self._i + batch_size
        x, y = self.images[start:stop], self.labels[start:stop]
        self._i = stop % len(self.images)
        return x, y
# 2. Define helper functions
# Directory containing the CIFAR-10 "python version" batch files.
DATA_PATH = "../dataset/cifar-10-batches-py/"


def unpickle(file):
    """Load one pickled batch file located under ``DATA_PATH``.

    Args:
        file: File name, joined onto ``DATA_PATH``.

    Returns:
        The unpickled object.  It is loaded with ``encoding="bytes"``, so
        string keys pickled by Python 2 come back as ``bytes``
        (e.g. ``b"data"``).
    """
    # NOTE: pickle can execute arbitrary code while loading; only use this on
    # dataset files obtained from a trusted source.
    with open(os.path.join(DATA_PATH, file), "rb") as fo:
        batch = pickle.load(fo, encoding="bytes")
    return batch


def one_hot(vec, vals=10):
    """Convert a vector of integer class indices into a one-hot matrix.

    Args:
        vec: Sequence of integer class indices, each in ``[0, vals)``.
        vals: Number of classes, i.e. number of output columns.

    Returns:
        Float array of shape ``(len(vec), vals)`` with a single 1 per row.
    """
    n = len(vec)
    out = np.zeros((n, vals))
    # Pair row k with column vec[k] and set that one cell per row.
    out[np.arange(n), vec] = 1
    return out
# 3. Define the data manager class
class CifarDataManager(object):
    """Hold the loaded CIFAR-10 training and test splits.

    Attributes:
        train: ``CifarLoader`` over ``data_batch_1`` .. ``data_batch_5``.
        test: ``CifarLoader`` over ``test_batch``.
    """

    def __init__(self):
        """Load both splits immediately; this reads the batch files."""
        train_files = ["data_batch_{}".format(i) for i in range(1, 6)]
        self.train = CifarLoader(train_files).load()
        self.test = CifarLoader(["test_batch"]).load()
# 4. Display images from the CIFAR-10 dataset
def display_cifar(images, size):
    """Show a ``size`` x ``size`` grid of randomly chosen images.

    Images are drawn independently with ``np.random.choice``, so the same
    image may appear more than once in the grid.

    Args:
        images: Indexable collection of equally shaped image arrays.
        size: Number of images per row and per column.
    """
    n = len(images)
    plt.figure()
    plt.gca().set_axis_off()
    # Build each row by joining images side by side, then stack the rows.
    rows = [
        np.hstack([images[np.random.choice(n)] for _ in range(size)])
        for _ in range(size)
    ]
    im = np.vstack(rows)
    plt.imshow(im)
    plt.show()
# 4.1 Load the CIFAR-10 dataset and display it
# Load both dataset splits from disk and report the training set size.
cifar = CifarDataManager()
print("Number of train images: {}".format(len(cifar.train.images)))