Cifar-10数据集
Cifar-10数据集是由10类32*32的彩色图片组成的数据集,一共有60000张图片,每类包含6000张图片。其中50000张是训练集,1000张是测试集。
数据集的下载地址:http://www.cs.toronto.edu/~kriz/cifar.html
1. 获取每个batch文件中的字典信息
import pickle
def unpickle(file):
fo = open(file,'rb')
dick = pickle.load(fo,encoding='latin1')
fo.close()
return dick
在字典结构中,每一张图片是以被展开的形式存储,即一张32*32*3的图片被展开成3072长度的list,每一个数据的格式为unit8,前1024为红色通道,中间1024为绿色通道,后1024为蓝色通道。
2.图像预处理。对数据进行标准化操作,按照一定比例进行缩放,使其落入一个特定的区域,便于操作处理。提高了处理速度。
import numpy as np
def clean(data):
imgs = data.reshape(data.shape[0],3,32,32) #shape[0] data第一维的长度(数量)
grayscale_imgs = imgs.mean(1)
cropped_imgs = grayscale_imgs[:,4:28,4:28]
img_data = cropped_imgs.reshape(data.shape[0],-1)
img_size = np.shape(img_data)[1]
means = np.mean(img_data,axis=1)
meansT = means.reshape(len(means),1)
stds = np.std(img_data,axis=1)
stdsT = stds.reshape(len(stds),1)
adj_stds = np.maximum(stdsT,1.0/np.sqrt(img_size))
normalized = (img_data - meansT) / adj_stds
return normalized
mean(axis)函数:求平均值。对m*n的矩阵来说
axis=0:压缩行,对各列求平均值,返回1*n矩阵。
axis=1:压缩列,对各行求平均值,返回m*1矩阵。
axis不设置值,对m*n个数求平均值,返回一个实数。
reshape()函数:改变数组的形状。
reshape((2,4)):变为一个二维数组;reshape((2,2,2)):变为一个三维数组
当有一个参数为-1时,会根据另一个参数的维度计算数组的另外一个shape属性值。
如reshape(data.shape[0],-1):行为data.shape[0]行,列自动算出。data.shape[0]:data第一维的长度。
3.图像数据读取
def read_data(directory):
names = unpickle('{}/batches.meta'.format(directory))['label_names']
print('dede')
print('names',names)
print('dede')
data,labels = [],[]
#一个batch一个batch的去读取batch数据
for i in range(1,6):
filename = '{}/data_batch_{}'.format(directory,i)
batch_data = unpickle(filename)
#拼加操作
if len(data) > 0:
data = np.vstack((data,batch_data['data']))
labels = np.hstack((labels,batch_data['labels']))
else:
data = batch_data['data']
labels &#