相信很多上过 cs231n 课程的人都尝试过查看 CIFAR-10 数据集,但是原始代码问题特别多,运行时还会报错。下面给出修改后的完整代码。
import os
import pickle

import numpy as np
def load_cifar_batch(filename):
    """Load one pickled CIFAR-10 batch file.

    Args:
        filename: Path to a batch file such as ``data_batch_1`` or
            ``test_batch``.

    Returns:
        A tuple ``(x, y)``: ``x`` is a float array of shape
        ``(N, 32, 32, 3)`` (image index, row, column, colour channel) and
        ``y`` is an array holding the ``N`` labels.

    Raises:
        KeyError: If the unpickled dict has no ``'data'``/``'labels'`` entry.
    """
    # SECURITY: unpickling can execute arbitrary code; only load batch files
    # that come from a trusted source.
    with open(filename, 'rb') as f:
        # encoding='latin1' decodes the Python 2 byte strings stored in the
        # file to str, so the dict keys below are plain str ('data'), not
        # bytes (b'data').  Mixing latin1 with bytes keys raises KeyError.
        datadict = pickle.load(f, encoding='latin1')
    x = datadict['data']
    y = datadict['labels']
    # Each row stores one image channel-first (3 x 32 x 32); move the channel
    # axis last so every image is 32 x 32 x 3.  -1 lets the number of images
    # follow the file instead of being fixed at 10000.
    x = x.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype('float')
    y = np.array(y)
    return x, y
def load_cifar10(root):
    """Load the CIFAR-10 training batches and the test batch found in *root*.

    Reads ``data_batch_1`` .. ``data_batch_5`` and ``test_batch`` through
    ``load_cifar_batch`` and joins the five training batches end to end.

    Args:
        root: Directory that contains the batch files.

    Returns:
        ``(Xtrain, Ytrain, Xtest, Ytest)`` as produced by
        ``load_cifar_batch``, with the training parts concatenated.
    """
    # Per the original author's note, loading all five batches (range(1, 6))
    # takes roughly 2 GB; switch to range(1, 2) to read a single batch.
    batch_paths = [
        os.path.join(root, 'data_batch_%d' % (number,))
        for number in range(1, 6)
    ]
    batches = [load_cifar_batch(path) for path in batch_paths]
    Xtrain = np.concatenate([images for images, _ in batches])  #1
    Ytrain = np.concatenate([labels for _, labels in batches])
    test_path = os.path.join(root, 'test_batch')
    Xtest, Ytest = load_cifar_batch(test_path)  #2
    return Xtrain, Ytrain, Xtest, Ytest
#1 将5份训练集拼接成一个数组。
#2 将1份测试集转化为数组。
将这份代码另存为 data_utils.py,接下来就要进行模型的训练和预测。
下面给出载入数据集的代码:
import numpy as np
from data_utils import load_cifar10
import matplotlib.pyplot as plt
from knn import KNearestNeighbor
# Load every CIFAR-10 split from the local dataset folder, then report the
# shape of each returned array.
x_train, y_train, x_test, y_test = load_cifar10('cifar-10-batches-py')
for caption, array in (
    ('training data shape:', x_train),
    ('training labels shape:', y_train),
    ('test data shape:', x_test),
    ('test labels shape:', y_test),
):
    print(caption, array.shape)
结果如下
training data shape: (50000, 32, 32, 3)
training labels shape: (50000,)
test data shape: (10000, 32, 32, 3)
test labels shape: (10000,)
最后来看如何展示图片,下面是展示图片的代码:
# Draw a grid of random training images: one column per class, one row per
# sample, with the class name as the title of each column's first image.
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# NOTE: the spelling of this name is kept as-is so any later code that refers
# to it keeps working.
num_claesses = len(classes)
samples_per_class = 7
for column, class_name in enumerate(classes):
    # Indices of all training images whose label equals this class index.
    members = np.flatnonzero(y_train == column)
    picked = np.random.choice(members, samples_per_class, replace=False)
    for row, image_index in enumerate(picked):
        # Subplot positions are 1-based and numbered row by row.
        plt.subplot(samples_per_class, num_claesses, row * num_claesses + column + 1)
        plt.imshow(x_train[image_index].astype('uint8'))
        plt.axis('off')
        if row == 0:
            plt.title(class_name)
plt.show()