自己写的数据集加载代码,在实现 KNN 无循环分类时出错,感觉是加载出来的数据格式不对。
还是老老实实用提供的代码,直接参考了大佬的文章 https://zhuanlan.zhihu.com/p/30748903 ,已完成。
错例:
import numpy as np
import matplotlib.pyplot as plt
import random
def unpickle(file):
    """Load a pickled file and return the object stored in it.

    Used in this script to read the CIFAR-10 batch files.

    Args:
        file: Path of the pickle file to open.

    Returns:
        The unpickled object. The file is decoded with
        ``encoding='bytes'``, so 8-bit string instances pickled by
        Python 2 (including dict keys) come back as ``bytes`` objects,
        e.g. ``b'data'``.
    """
    import pickle

    # SECURITY: pickle can run arbitrary code while loading. Only call
    # this on dataset files that come from a trusted source.
    with open(file, 'rb') as fo:
        # Named `batch` so the builtin `dict` is not shadowed.
        batch = pickle.load(fo, encoding='bytes')
    return batch
# Collect the five training batches in plain Python lists first.
X_train, y_train, X_test, y_test = [], [], [], []
for i in range(1, 6):
    data = unpickle('datasets/pankrzysiu-cifar10-python-momodel/cifar-10-batches-py/data_batch_' + str(i))
    # extend() appends the batch element by element (one entry per
    # image row / per label), so the lists grow across all five files.
    X_train.extend(data[b'data'])
    y_train.extend(data[b'labels'])

# The test split is a single file, so it is used as loaded.
data = unpickle('datasets/pankrzysiu-cifar10-python-momodel/cifar-10-batches-py/test_batch')
X_test = data[b'data']
y_test = data[b'labels']
print(data.keys())
print(type(X_train))
# Recorded output of the two prints above:
# dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
# <class 'list'>

# Convert everything to numpy arrays.
X_train, y_train = np.array(X_train), np.array(y_train)
X_test, y_test = np.array(X_test), np.array(y_test)
# The variables are now <class 'numpy.ndarray'>.
# NOTE(review): the pixel rows keep whatever dtype the pickle stores
# (presumably uint8 for CIFAR-10 -- confirm). Integer arithmetic on them
# can wrap around, which may be the source of the KNN error described
# in the notes; cast to float before computing distances.

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(classes)
# Quick visual check of one training image.
import cv2 as cv
from PIL import Image

# A CIFAR-10 row holds the three colour planes one after another
# (1024 red, then 1024 green, then 1024 blue values). Split it into
# (channel, height, width) and move the channel axis last; reshaping
# straight to (32, 32, 3) interleaves the planes and scrambles the image.
im = Image.fromarray(X_train[3].reshape(3, 32, 32).transpose(1, 2, 0))
im.save('1.png')
# An OpenCV alternative (cv.imwrite to '1.jpg') was tried in the original
# notes and left disabled.
# Read the saved file back as grayscale (flag 0); `img` is not used any
# further in this snippet.
img = cv.imread('1.png', 0)

plt.figure()
plt.imshow(im)
# Hide the axis ticks so that only the image is shown.
plt.xticks([])
plt.yticks([])
plt.show()
# Flatten every sample into one row: the first axis (sample count) is
# kept and all remaining axes are merged into a single feature axis.
X_train = X_train.reshape(X_train.shape[0], -1)
X_test = X_test.reshape(X_test.shape[0], -1)
print(X_train.shape, X_test.shape)
# Recorded output: (50000, 3072) (10000, 3072)
改成下面这样就可以了~菜鸟高兴~~~~
def unpickle(file):
    """Load a pickled file and return the object stored in it.

    Used in this script to read the CIFAR-10 batch files.

    Args:
        file: Path of the pickle file to open.

    Returns:
        The unpickled object. The file is decoded with
        ``encoding='bytes'``, so 8-bit string instances pickled by
        Python 2 (including dict keys) come back as ``bytes`` objects,
        e.g. ``b'data'``.
    """
    import pickle

    # SECURITY: pickle can run arbitrary code while loading. Only call
    # this on dataset files that come from a trusted source.
    with open(file, 'rb') as fo:
        # Named `batch` so the builtin `dict` is not shadowed.
        batch = pickle.load(fo, encoding='bytes')
    return batch
# X_test / y_test are only initialised here; this snippet loads the
# training batches and does not read the test batch.
X_train, y_train, X_test, y_test = [], [], [], []
for i in range(1, 6):
    data = unpickle('datasets/pankrzysiu-cifar10-python-momodel/cifar-10-batches-py/data_batch_' + str(i))
    # This is the key transformation: view the batch as
    # (10000 images, 3 channels, H, W), then move the channel axis last
    # to get (10000, 32, 32, 3), and store the pixels as floats.
    X = data[b'data'].reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    Y = np.array(data[b'labels'])
    X_train.append(X)
    y_train.append(Y)

# Join the per-batch arrays along the sample axis. np.array() would
# instead stack them into a (5, 10000, 32, 32, 3) array, so indexing a
# single image such as X_train[5] would fail and the labels would end up
# with shape (5, 10000).
X_train = np.concatenate(X_train)
y_train = np.concatenate(y_train)

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(classes)

# Show one training image; cast back to uint8 for display because the
# pixels were converted to float above.
plt.figure()
plt.imshow(X_train[5].astype('uint8'))
# Hide the axis ticks so that only the image is shown.
plt.xticks([])
plt.yticks([])
plt.show()