机器学习实战:k近邻算法

原理解析

什么是k近邻算法呢?听起来挺高大上的,其实算法思想很简单:

  1. 计算待分类数据与已知分类数据的距离(一般我们取欧几里得距离,就是几何距离)
  2. 对距离从小到大进行一个排序,取距离最小的k个值
  3. 取这k个值中出现频率最高的标签作为分类结果

是不是很浅显的道理,可以用一张图来简单的说明一下原理:

k近邻算法原理图

完全符合我们正常的思路。下面用一个实例进行一个简单的代码实现,不妨假设坐标轴上有两块区域A和B,其中区域A位于点(1,1)附近,区域B位于(0,0)附近,我们创建一个符合该假设的数据集:

def createDataSet():
    group = array([[1.0,1.1], [1.0,1.0], [0,0], [0,0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels

该函数创建了一个包含四个点的数据集,其中两个点位于区域A,两个点位于区域B。给定一个未知点(0.1, 0.1),很显然我们希望程序能预测其位于区域B,下面我们来应用knn来对该未知点进行一个分类。

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, dataSet) - dataSet # tile: tile(A,rep) 功能:重复A的各个维度 rep:A沿着各个维度重复的次数
    sqDiffMat = diffmat**2
    sqDiff = sqDiffMat.sum(axis=1)
    distances = sqDiff**0.5 # 计算输入点与已知各点的距离,得到一个一维数组

    sortedIndices = distances.argsort() # 返回距离递增排序的下标
    classCount = {}
    for i in range(k): # 循环获取最短的k个距离的标签
        votedLabel = labels[sortedIndices[i]]
        classCount[votedLabel] = classCount.get(votedLabel, 0)+1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) # 根据标签出现频率排序,获得最大频率标签即分类结果
    return sortedClassCount[0][0]

运行上面的函数可以得到分类结果为B,符合我们的预期。

实战:手写数字识别

这里我们会用到著名的MNIST数据集。MNIST是一个标定好的手写数字数据集,包含训练数据60000条,测试数据10000,每一组数据包含一个28*28的矩阵,我们会把它展成784*1的向量,其中使用0-255灰度值表示对应像素点处的颜色,然后每一组数据都会有对应标定好的标签,文件名像这样:

mnist.png

首先是数据的预处理,可以参考这篇博文(这个奇怪的文件格式我也没见过),处理过程贴代码:

def load_mnist(path, kind='train'):
    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte'
                               % kind)
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II',
                                 lbpath.read(8))
        labels = np.fromfile(lbpath,
                             dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII',
                                               imgpath.read(16))
        images = np.fromfile(imgpath,
                             dtype=np.uint8).reshape(len(labels), 784)

    return images, labels

然后我们拿到了我们期待的(60000*784)的一个训练矩阵+(60000,)的一个训练标签向量(knn其实没有训练模型这一步,每一轮都是根据已有的数据直接推测),同样(10000*784)+(10000,)的一个测试集。然后就可以套用我们的knn算法开始进行预测啦:

# 加载数据集
t0 = time()
print('loading mnist datasets...')
trainImgs, trainLbs = load_mnist('./mnist', 'train')
testImgs, testLbs = load_mnist('./mnist', 't10k')
print('datasets have loaded successfully after %d seconds' % (time()-t0))
# 在测试集上测试算法的准确度
testLen = testImgs.shape[0]
total = 0
error = 0
startTime = time()
for i in range(testLen):
    total += 1
    predict = knn.classify0(testImgs[i], trainImgs, trainLbs, 10)# 这一行划重点
    print('predict result: %s, real result: %s, %d seconds have passed' % (predict, testLbs[i], time()-startTime))
    if(predict != testLbs[i]): error += 1
    if(total % 100 == 0): print('error rate: %f' % (float(error)/total))

很遗憾,预测结果跟瞎猜没什么两样。。上网查了一下别人的测试,据说成功率很高(使用KNN对MNIST数据集进行实验),有点质疑他的代码执行结果。。。然后对矩阵做了一下归一化处理(好像不这么叫),也就是所有数据都除了一个255:predict = knn.classify0(testImgs[i], trainImgs/255, trainLbs/255, 10),然而也并没有效果(其实也说明我没有完全理解距离的含义,还有些靠猜)。最后只能参照书上的方法,对灰度值做一个二值化处理,代码像这样:

def norm(imageMat, valve=0):
    for imageVec in imageMat:
        for i in range(len(imageVec)):
            imageVec[i] = 1 if imageVec[i]>valve else 0

处理完成后的结果像这样:

二值化处理完成后的图像

测试的时候很粗暴的用0作为阈值进行了二值化处理,就已经取得了相当好的预测结果(准确率目测达到了94%以上,电脑太卡,不想跑代码了),实际使用的话还可以根据预测结果调整到最合适的阈值以提高准确率。k近邻算法的另一个重要参数k也值得推敲,有需要的话查查资料吧。

附上一张运行结果图:

程序运行结果

k近邻算法是最最基础的一个机器学习算法,很明显的缺点是前面的训练(都不能叫训练)没有为后面的预测带来任何福利,算法复杂度相当高,每次都要做一堆的距离运算,非常耗时;另一个缺陷是我们完全没有关注数据的基础结构信息,统一展成了向量,这些问题可以用另一个算法来解决,就是下次要说的决策树。

理解肯定有偏差,欢迎各位大佬指正。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值