cs231n系列(1)-KNN分类器

首先先从NN(nearest-neighbor)分类器开始介绍,下图是使用NN分类器对CIFAR-10数据库进行分类的一个结果。


可以看出,分类出来的效果并不好,如将车分类成了horse类。那么为什么会造成这样的结果呢?是否可以优化呢?是不是意味着NN或者KNN在实际中就没有用处了呢?我们带着这些疑问继续去学习以下内容。

那么NN分类器具体如何比较两张图片呢?在本例中,就是比较32x32x3的像素块。方法就是图像逐个像素比较,最后将差异值全部加起来。换句话说,就是将两张图片先转化为两个向量I_1I_2,然后计算他们的L1距离(1阶范数)

\displaystyle d_1(I_1,I_2)=\sum_p|I^p_1-I^p_2|

计算出的L1距离越小,意味着与训练图越相似,当为0的时候,两张图就是一模一样了。(简单了解下L1距离的意义,具体可以点击链接查看)。

由此,我们可以总结出NN分类器的流程


到这里,大概简明的阐述了NN分类器的原理。那么现在能不能解决我们所提出的问题呢。我们可以看出horse类别的图片背景是大量的黑色,根据NN分类器原理,意味着极有可能有很多黑色背景的图片会被误认为horse类别,结果正好证实了我们这一点。正如我们所料,NN分类器在这个数据集中正确率只有30%左右,和我们人的视觉的识别率有着大大的差距,所以NN分类器在这个方面非常不可靠,同时每一次测试集元素都需要去遍历所有训练集元素,当数据集非常大时,其中的运算开销足以让我们退避三舍。暂且不说时空复杂度的问题,我们有办法提高一些分类的正确率吗?可能你想到了投票,是的,由此我们引出KNN分类器,其实原理是一样的,只是我们去取L1距离或L2距离最小的K个训练集元素,让这K个元素进行投票,票数多的就是当前测试集元素的分类。所以NN分类器也可以看做是K为1的KNN分类器。到此,KNN分类器原理阐述完毕。接下来是py代码给大家参考

# -*- coding: utf-8 -*-
#声明保存文件编码

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

def unpickle(file):
    import cPickle
    fo = open(file,'rb')
    dict = cPickle.load(fo)
    fo.close()
    return dict

def load_CIFAR10(file):
    dataTrain = []#训练集数据
    labelTrain = []#训练集标签

    #训练集中有5个文件
    for i in range(1,6):
        dic = unpickle(file+"\\data_batch_"+str(i))
        for item in dic["data"]:
            dataTrain.append(item)
        for item in dic["labels"]:
            labelTrain.append(item)
    
    #加载测试数据和标签
    dataTest = []
    labelTest = []
    dic = unpickle(file+"\\test_batch")
    for item in dic["data"]:
        dataTest.append(item)
    for item in dic["labels"]:
        labelTest.append(item)
    return (dataTrain,labelTrain,dataTest,labelTest)

dataTr,labelTr,dataTe,labelTe = load_CIFAR10("F:\\photoWork\\cs231n\\workplace\\KNN\\cifar-10-batches-py")
#将数组转为numpy类型的array,方便操作
Xtr = np.asarray(dataTr)
Xte = np.asarray(dataTe)
Ytr = np.asarray(labelTr)
Yte = np.asarray(labelTe)
#将数据集拉长成为行向量了
Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3) 
Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3) 

class NearestNeighbor(object):
    def __init__(self):
        pass
    #训练集赋值
    def train(self,X,Y):
        self.Xtr = X
        self.Ytr = Y
    #进行预测
    def predict(self, X):
        num_test = X.shape[0]#获得测试集的个数
        Ypred = np.zeros(num_test,dtype = self.Xtr.dtype)#初始化0矩阵

        for i in xrange(num_test):
            #计算训练集和当前测试图片的L1值
            distances = np.sum(np.abs(self.Xtr - X[i,:]),axis=1)
            #取出最小值,即最相近的图片下标
            min_index = np.argmin(distances)
            #将标签值存入Ypred[i]
            Ypred[i] = self.Ytr[min_index]
        return Ypred

nn = NearestNeighbor()
#训练集赋值
nn.train(Xtr_rows, Ytr)
#获得预测标签数组
Yte_predict = nn.predict(Xte_rows)
print 'accuracy: %f' % ( np.mean(Yte_predict == Yte) )

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值