import os
import pickle
from collections import Counter

import numpy as np
from tqdm import tqdm
class KNNClassifier:
    """k-nearest-neighbour classifier for CIFAR-10 using L1 (Manhattan) distance.

    ``train`` only loads (memorises) the data set; every prediction compares
    the query image against all training images and lets the ``topk`` closest
    ones vote on the label.
    """

    # Training rows compared per step in ``_l1_distances``; bounds the size of
    # the temporary difference array instead of materialising (N, D) at once.
    _BATCH_ROWS = 1000

    def __init__(self, k, path='./cifar-10-batches-py/'):
        """Create an (untrained) classifier.

        Args:
            k: number of nearest neighbours that vote on each prediction.
            path: directory containing the pickled CIFAR-10 python batches
                (``data_batch_1`` .. ``data_batch_5`` and ``test_batch``).
        """
        self.path = path
        self.trainX = []
        self.trainY = []
        self.testX = None
        self.testY = None
        self.topk = k

    def train(self):
        """Load the training and test data; kNN has no other training step."""
        self.load_cifar()

    def predict(self, test_pic):
        """Return the majority label among the ``self.topk`` nearest training images.

        Args:
            test_pic: one image with the same number of values as a row of
                ``self.trainX``.

        Returns:
            The winning label. Vote ties go to the label whose first voter is
            closest to ``test_pic``.

        Raises:
            ValueError: if ``self.topk`` < 1 or no training data is loaded.
        """
        if self.topk < 1:
            raise ValueError("k must be at least 1, got {}".format(self.topk))
        dists = self._l1_distances(test_pic)
        # Stable sort: equal-distance neighbours keep training-set order, so the
        # selected neighbours (and hence the vote) are deterministic.
        nearest = np.argsort(dists, kind='stable')[:self.topk]
        labels = np.asarray(self.trainY)[nearest].tolist()
        # most_common orders equal counts by first occurrence, i.e. nearest first.
        return Counter(labels).most_common(1)[0][0]

    def _l1_distances(self, test_pic):
        """Return the L1 distance from ``test_pic`` to every row of ``self.trainX``."""
        n_train = len(self.trainX)
        if n_train == 0:
            raise ValueError("no training data loaded; call train() first")
        # Promote to float64 before subtracting: the CIFAR-10 batches store
        # pixels as uint8, and uint8 subtraction wraps around (3 - 5 == 254),
        # which would corrupt the distances. Integer sums stay exact in float64.
        query = np.asarray(test_pic, dtype=np.float64).ravel()
        dists = np.empty(n_train, dtype=np.float64)
        for start in range(0, n_train, self._BATCH_ROWS):
            stop = start + self._BATCH_ROWS
            chunk = np.asarray(self.trainX[start:stop], dtype=np.float64)
            chunk = chunk.reshape(len(chunk), -1)
            dists[start:stop] = np.abs(chunk - query).sum(axis=1)
        return dists

    def test(self, num_test=100):
        """Classify the first ``num_test`` test images and print the accuracy.

        Args:
            num_test: how many test images to evaluate (default 100).

        Returns:
            The accuracy as a fraction in [0, 1]; it is also printed.

        Raises:
            ValueError: if no test images are available.
        """
        if self.testX is None:
            raise ValueError("no test data loaded; call train() first")
        test_x = self.testX[:num_test]
        if len(test_x) == 0:
            raise ValueError("no test images to evaluate")
        test_y = np.asarray(self.testY[:num_test])
        preds = np.array([self.predict(pic) for pic in tqdm(test_x)])
        correct = np.count_nonzero(preds == test_y)
        acc = correct / len(test_y)
        print("acc: {acc}".format(acc=acc))
        return acc

    def _load_batch(self, name):
        """Unpickle one CIFAR-10 batch file and return ``(data, labels)`` arrays.

        ``data`` is reshaped to (n, 32 * 32 * 3) and ``labels`` to (n,).
        """
        # NOTE(review): pickle can execute arbitrary code on load; only point
        # ``self.path`` at trusted dataset files.
        with open(os.path.join(self.path, name), 'rb') as f:
            batch = pickle.load(f, encoding='iso-8859-1')  # keys: labels, data, filenames
        n = len(batch['data'])
        data = np.reshape(batch['data'], (n, 32 * 32 * 3))
        labels = np.reshape(batch['labels'], (n,))
        return data, labels

    def load_cifar(self):
        """Load the five training batches and the test batch from ``self.path``.

        Sets ``trainX`` (N, 3072), ``trainY`` (N,), ``testX`` and ``testY``.
        Safe to call repeatedly: each call replaces the previously loaded data.
        """
        # Collect into fresh locals so a second call doesn't try to append to
        # the ndarrays produced by the first one.
        batches = [self._load_batch('data_batch_' + str(i)) for i in range(1, 6)]
        self.trainX = np.concatenate([data for data, _ in batches])
        self.trainY = np.concatenate([labels for _, labels in batches])
        self.testX, self.testY = self._load_batch('test_batch')
if __name__ == '__main__':
    # Number of nearest neighbours that vote on each test image.
    k = 100
    knn = KNNClassifier(k)
    knn.train()  # loads the CIFAR-10 batches from disk
    knn.test()
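# 在cifar-10数据集中实现kNN分类 (Implementing kNN classification on the CIFAR-10 dataset)
# 最新推荐文章于 2024-04-18 23:09:05 发布 (web-page residue: "latest recommended article published 2024-04-18 23:09:05")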