【每天码一点】cs231n作业一:KNN

自己想的加载数据集,在实现KNN无循环分类时有错,感觉是加载出来的格式不对。

还是老实用提供的代码,就直接参考大佬的https://zhuanlan.zhihu.com/p/30748903,已完成。

错例:

import numpy as np
import matplotlib.pyplot as plt
import random

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

X_train,y_train,X_test,y_test=[],[],[],[]

for i in range(1,6):
    data=unpickle('datasets/pankrzysiu-cifar10-python-momodel/cifar-10-batches-py/data_batch_'+str(i))
    X_train.extend(data[b'data'])
    y_train.extend(data[b'labels'])

data=unpickle('datasets/pankrzysiu-cifar10-python-momodel/cifar-10-batches-py/test_batch')
X_test=data[b'data']
y_test=data[b'labels']

print(data.keys())
print(type(X_train)
#dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
#<class 'list'>

X_train,y_train=np.array(X_train),np.array(y_train)
X_test,y_test=np.array(X_test),np.array(y_test)
#<class 'numpy.ndarray'>

classes=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes=len(classes)

#胡乱显示了下
import cv2 as cv
from PIL import Image
im=Image.fromarray(X_train[3].reshape(32,32,3))
im.save('1.png')
# cv.imwrite('1'+'.jpg',X_train[3].reshape(32,32,3))
img=cv.imread('1.png',0)
plt.figure()
plt.imshow(im)
plt.xticks([]),plt.yticks([])
plt.show()


X_train = np.reshape(X_train, (X_train.shape[0], -1))
X_test = np.reshape(X_test, (X_test.shape[0], -1))
print(X_train.shape, X_test.shape)
#(50000, 3072) (10000, 3072)

可以了~菜鸟高兴~~~~

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

X_train,y_train,X_test,y_test=[],[],[],[]
for i in range(1,6):
    data=unpickle('datasets/pankrzysiu-cifar10-python-momodel/cifar-10-batches-py/data_batch_'+str(i))
    X=data[b'data'].reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")#就是这个变换(10000张图片,3通道,H,W)
    Y=np.array(data[b'labels'])
    X_train.append(X)
    y_train.append(Y)


X_train=np.array(X_train)
y_train=np.array(y_train)

classes=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes=len(classes)

plt.figure()
plt.imshow(X_train[5].astype('uint8'))
plt.xticks([]),plt.yticks([])
plt.show()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值