kNN在CIFAR10上的应用

1. 获取CIFAR10

CIFAR10是一个10分类的图片数据集,主页在这里,作者使用python版本的数据集。

2. 加载数据集

在主页上已有加载数据集的代码,数据集分成了5个训练用的batch和1个test batch,每个batch有10000张32x32x3的图片,还有一个batches.meta文件装着label对应的名字。


不妨贴出我的代码:

def load_data(root, batch):
    ''' @brief: There are 5 batches and a test-batch
        in ../datasets/cifar-10. 每个batch打开有key:['data,
        labels, batch_label, filenames']
        @param batch: batch-n/test-batch
    '''
    batch_path = os.path.join(root, batch)
    with open(batch_path, 'rb') as f:
        dataset = pickle.load(f)
    return dataset

def load_label_names(root):
    ''' @brief: 装载batches.meta,包含了label_names '''
    meta_path = os.path.join(root, 'batches.meta')
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    return meta['label_names']


在dataset这个dict里最有用的是data和labels两个key,分别对应10000x3072的图像数据和10000个标签。


3. kNN

kNN的思想是对需要确定类别的数据,在已知类别的数据集上找到与它距离最近的k个数据,根据这k个数据各自属于的类别对新数据的类别进行投票,少数服从多数。就像这样(百度百科贴过来的):


写成代码就像这样:

class kNN:
    ''' 实现kNN分类器 '''
    def __init__(self):
        self.Xtr = None
        self.Ytr = None

    def __init__(self, X, Y):
        self.Xtr = X
        self.Ytr = Y

    def train(self, X, Y):
        self.Xtr = X
        self.Ytr = Y

    def predict(self, x, k=1):
        distances = np.sum((self.Xtr - x)**2, axis=1)
        k_labels = [self.Ytr[x] for x in np.argsort(distances)][:k]
        u, counts = np.unique(k_labels, return_counts=True)
        return u[np.argmax(counts)]


嘛,其实主要的东西都在predict里,这样写只是个套路。

kNN能够设置的参数就是两个,距离度量和k值,作者写的距离是欧几里得距离,就是相减平方加和,也可以用其他距离试试。

k值可以通过实验确定,作者先在一个batch上玩玩,将batch分为训练集、验证集和测试集,分割比例是7:2:1,先通过验证集确定一个较好的k值。

代码:

    acc_vs_k = []
    knn = kNN(train_data, train_labels)
    k_list = range(1,11)
    acc_list = []
    for k in k_list:
        correct_num = 0
        now = time.time()
        for i in xrange(val_size):
            pred_label = knn.predict(val_data[i], k)
            true_label = val_labels[i]
            if pred_label == true_label:
                correct_num += 1
        acc = correct_num * 1.0 / val_size
        acc_list.append(acc)

把k值和accuracy对应的图显示出来


在1~10这个范围内最优的k值是5


评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值