在cifar-10数据集中实现kNN分类

import pickle
from collections import Counter

import numpy as np
from tqdm import tqdm


class KNNClassifier:
    """k-nearest-neighbour image classifier for CIFAR-10 using L1 distance.

    Each image is one flattened row of ``32 * 32 * 3`` values (see
    ``load_cifar``). A query is labelled by a majority vote among the ``k``
    training rows with the smallest L1 distance (sum of absolute differences).
    """

    def __init__(self, k):
        """Create an empty classifier.

        Args:
            k: number of nearest neighbours that vote on each prediction.
        """
        self.path = './cifar-10-batches-py/'  # directory holding the pickled batches
        self.trainX = []
        self.trainY = []
        self.testX = None
        self.testY = None
        self.topk = k

    def train(self):
        """Fit the model; for kNN this only loads (memorises) the data."""
        self.load_cifar()

    def test(self, num_test=100, chunk_size=1000):
        """Classify the first ``num_test`` test rows and report the accuracy.

        Args:
            num_test: how many rows from the start of ``testX`` to classify;
                ``None`` classifies all of them.
            chunk_size: number of training rows widened to a larger integer
                type at a time while computing distances (bounds peak memory).

        Returns:
            The accuracy as a float in [0, 1] (0.0 when nothing was
            classified). It is also printed as ``acc: <value>``.

        Raises:
            ValueError: if ``k`` is smaller than 1 or no training data is loaded.
        """
        if self.topk < 1:
            raise ValueError('k must be at least 1, got {}'.format(self.topk))
        train_x = np.asarray(self.trainX)
        train_y = np.asarray(self.trainY)
        if len(train_x) == 0:
            raise ValueError('no training data loaded; call train() first')
        # Never ask for more neighbours than there are training rows.
        n_neighbors = min(self.topk, len(train_x))

        test_x = self.testX[:num_test]
        test_y = np.asarray(self.testY[:num_test])
        predictions = np.array([
            self._predict_one(sample, train_x, train_y, n_neighbors, chunk_size)
            for sample in test_x
        ])
        correct = np.count_nonzero(predictions == test_y)
        # Divide by the number actually classified, not a hard-coded 100.
        acc = correct / len(test_y) if len(test_y) else 0.0
        print("acc: {acc}".format(acc=acc))
        return acc

    @staticmethod
    def _l1_distances(sample, train_x, chunk_size):
        """Return the L1 distance from ``sample`` to every row of ``train_x``."""
        sample = np.asarray(sample)
        # CIFAR-10 pixels are uint8: subtracting them directly wraps modulo 256
        # (0 - 1 == 255), so compute in at least int64 (float stays float).
        work_dtype = np.result_type(train_x.dtype, sample.dtype, np.int64)
        sample = sample.astype(work_dtype)
        dists = np.empty(len(train_x), dtype=work_dtype)
        for start in range(0, len(train_x), chunk_size):
            block = train_x[start:start + chunk_size].astype(work_dtype)
            dists[start:start + chunk_size] = np.abs(block - sample).sum(axis=1)
        return dists

    def _predict_one(self, sample, train_x, train_y, n_neighbors, chunk_size):
        """Return the majority label among the ``n_neighbors`` closest rows."""
        dists = self._l1_distances(sample, train_x, chunk_size)
        # Stable sort: among equal distances the lower training index comes first.
        nearest = np.argsort(dists, kind='stable')[:n_neighbors]
        # most_common breaks count ties by first occurrence, i.e. in favour of
        # the label whose member is nearest.
        votes = Counter(train_y[nearest].tolist())
        return votes.most_common(1)[0][0]

    def load_cifar(self):
        """Load the five training batches and the test batch from ``self.path``.

        Each pickled batch is a dict whose ``'data'`` is reshaped to rows of
        ``32 * 32 * 3`` values and whose ``'labels'`` is flattened to 1-D.
        Calling this again reloads from scratch instead of appending duplicates.
        """
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        train_x, train_y = [], []
        for i in range(1, 6):
            with open(self.path + 'data_batch_' + str(i), 'rb') as f:
                batch = pickle.load(f, encoding='iso-8859-1')  # labels\data\filenames
            train_x.append(np.reshape(batch['data'], (-1, 32 * 32 * 3)))
            train_y.append(np.reshape(batch['labels'], (-1,)))
        self.trainX = np.concatenate(train_x)
        self.trainY = np.concatenate(train_y)
        with open(self.path + 'test_batch', 'rb') as f:
            batch = pickle.load(f, encoding='iso-8859-1')
        self.testX = np.reshape(batch['data'], (len(batch['data']), 32 * 32 * 3))
        self.testY = np.reshape(batch['labels'], (len(batch['data']),))


if __name__ == '__main__':
    # Number of nearest neighbours that vote on each prediction. Kept as a
    # module-level name so any code that reads the global ``k`` still finds it.
    k = 100
    knn = KNNClassifier(k)
    knn.train()  # loads the CIFAR-10 batches from ./cifar-10-batches-py/
    knn.test()   # classifies the first 100 test images and prints the accuracy
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值