机器学习之K近邻法及其python实现

写在前面

本文主要是学习记录贴,参考《统计学习方法》和部分博客完成。如有错误,欢迎积极评论指出。

初步理解

上一篇讲述了感知机模型,当时是用一个超平面将数据集完整分割开来,但是在我们的生活经验中,是否存在其他的思路呢?这时,一个词出现了,那就是——物以类聚,人以群分。比如上面咱们用的一张图:
在这里插入图片描述
两种颜色,代表两个不用的标签,我们很容易发现,相同颜色的样本,通常是以成团的形式聚集在一块的。换句话说,标签相同的产品在特征上肯定是相同或相近的。
那么根据这个思想,我们是不是可以根据我们需要判断的样本点,在数据集中的位置,来判断它的标签类别。事实上,k近邻就是基于这个原理的算法。下面,我们就来梳理一下这个算法。

k近邻法

思想: k近邻法十分简单,就是给定一个训练数据集,对于新的输入实例,在训练数据集中找到与实例最邻近的k个实例,这k个实例的多数属于哪个类,就把该实例分为这个类。

算法: 在这里插入图片描述
三要素: 根据上面的算法,结合下面的图:
在这里插入图片描述
上述样本为训练数据集,现在,我们输入一个新的实例,样本点的位置如下(红色点):
在这里插入图片描述
根据上述的算法流程,首先我们应该找到离红色点最近的k个点,那么第一个要素就出来了——距离度量。也就是我们怎么去衡量两个点的远近程度。书中的定义是这样的:
在这里插入图片描述
当然这里面,我们最熟悉的就是欧氏距离。同时在我们的日常使用中,也是欧式距离最为常用。
那么解决完距离的问题,接下来是k值,假设我们k=1,我们通过计算会发现,距离红色点最近的点是绿色的点,那么我们是不是可以判断,红色点的分类应该和绿色保持一致?那么我们是不是k值越大越好,我们选取k值将所有的训练数据集都包含进去,是不是效果更好?当然不是,k值较少的时候,预测结果会对近邻的实例点十分的敏感,尤其是实例点旁边正好是噪声的时候,就会预测出错。那个当k值选取较大的时候,确实可以降低噪声的影响,等于使模型变得简单,但是相应的也会出现问题,那就是k值选取过大时,可能与输入实例较远(不相似)的点会对预测起作用,使预测发生错误。
实际的使用过程中,环境比较复杂,我们通常是使用交叉验证的方式来选取最优的k值。
最后,我们根据上面流程,选取了红色点附近的k个点,但是我们发现,这k个点并不是同属于一类的,这时,就出现一个问题,我们该怎么在这个类中找到输入实例的分类。也就是分类决策规则问题。根据我们的常识,也是从众心理,谁多我们就听谁的。即由输入实例的k个近邻的训练实例中的多数类决定输入实例的类。

问题

根据描述,我们在实际的使用过程中,每判断一个输入的实例点,都要将所有的训练样本遍历一点。这这这。。也太蠢了吧,针对大规模的样本,我们每次都要全部重新计算一边所有训练样本点与实例点的距离,这是不现实的。不知道你们有没有发现,k近邻法很像我们在算法考试中遇见的一个问题,就是在一个数组中,找到与值A最接近的元素。那么最简单的方法,就是遍历数组中的所有元素,在一一判断。但是这样子明显有点蠢,那么我们想到上课中,老师常用的一个方法——二分法。我们先将数组排序,根据中间的元素与值A的大小,确定接下来应该和哪一部分进行比较。那么,这个思想是不是可以用到k近邻法上?答案是肯定的,也就是kd树的实现。

kd树

kd树的构造

算法:
在这里插入图片描述
在这里插入图片描述
理解相对比较简单,话不多说,上代码:

from collections import namedtuple
import math
import operator

# 用来存放递归结果
result = namedtuple('result_tuple', 'nearest_point dist')

# 定义的二叉树模型
class KD_Node(object):
    def __init__(self, data=None, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

# 读取数据,将图片读取为785维的向量,其中前784维是28*28,最后是标签
def LoadData(filename):
    data = []
    i = 0
    file = open(filename)
    for line in file.readlines():
        curline = line.strip().split(',')
        data.append([int(int(dt) > 128) for dt in curline[1:]])   #二值化
        data[i].append(int(curline[0]))
        i += 1
    return data

# 这里是我自己尝试的将原图进行特征提取,提取出13维特征向量
# 前8维是将图片分割为8个区域,统计其中数值为1的个数
# 接下来两维将图片横向切割为4部分,分别统计中间两部分的黑色点数目
# 接下来两维是统计纵向切割后中间两部分
# 最后一维是整个图片中黑色点的数目
#将标签放在向量最后,构成了一整个特征向量
def CreatFeature(data):
    data = list(data)
    retdata = []
    for dt in data:
        temp = []
        for i in range(4):
            num1 = 0
            num2 = 0
            for j in range(7):
                for k in range(14):
                    if dt[i*112 + j*14  + k] == 1:
                        num1 += 1
                    if dt[i*112 + j*14  + k + 14] == 1:
                        num2 += 1
            temp.append(num1)
            temp.append(num2)
        temp.append(temp[2] + temp[3])
        temp.append(temp[4] + temp[5])
        num3 = 0
        num4 = 0
        for i in range(28):
            for j in range(7):
                if dt[i*28 + j + 7] == 1:
                    num3 += 1
                if dt[i*28 + j + 14] == 1:
                    num4 += 1
        temp.append(num3)
        temp.append(num4)
        temp.append(temp[0] + temp[1] + temp[2] + temp[3])
        temp.append(dt[-1])
        retdata.append(temp)
    return retdata

# 根据书上的内容,递归实现kd树
# data:训练数据
# depth:深度(0为根节点)
def creat_kd_tree(data, depth):
    if 0 == len(data):
        return None
    split_value = depth % (len(data[0]) - 1) # 减去1的目的是不读取最后一维,也就是标签
    media_value = len(data) // 2
    data.sort(key=lambda x: x[split_value])
    node = KD_Node(data[media_value])
    node.left = creat_kd_tree(data[:media_value], depth + 1)
    node.right = creat_kd_tree(data[media_value+1:], depth + 1)
    return node

# kd树的搜索
# tree:kd树
# depth:搜索深度
# data:预测数据
# delvalue:包含已经找到的前几个最近点,后续搜索将不再搜索该数据
def SearchMin(tree, depth, data, delvalue, max_dist=float('inf')):
    if tree is None:
        return result([0]*len(data), float('inf'))
    split_value = depth % (len(data) - 1) # 搜索维度
    split_axis = tree.data # 轴
    if data[split_value] <= split_axis[split_value]:
        next_tree = tree.left
        other_tree = tree.right     #记录另一结点,方便后续搜寻
    else:
        next_tree = tree.right
        other_tree = tree.left
    temp_result = SearchMin(next_tree, depth+1, data, delvalue, max_dist)   #递归搜索
    nearest = temp_result.nearest_point
    dist = temp_result.dist

    if dist < max_dist:
        max_dist = dist
    point_to_axis = abs(split_axis[split_value] - data[split_value])    #判断超球体是否与轴相交
    if max_dist < point_to_axis:
        return result(nearest, dist)
    temp_dist = math.sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(split_axis, data)))  #计算欧氏距离
    if split_axis not in delvalue:
        if temp_dist < dist:
            nearest = split_axis    #更新最近点和最近距离
            dist = temp_dist
            max_dist = dist
    if other_tree != None:
        temp_other_result = SearchMin(other_tree, depth+1, data, delvalue, max_dist)    #递归搜索另一结点
        if temp_other_result.nearest_point not in delvalue:
            if temp_other_result.dist < dist:
                nearest = temp_other_result.nearest_point
                dist = temp_other_result.dist
    return result(nearest, dist)
    
# 测试准确度
def test(tree, data, k):
    err = 0
    num = 1
    for i in data:
        print('测试数据编号:', num)
        num += 1
        K_value = [0] * 10
        result = []
        for j in range(k):
            temp = SearchMin(tree, 0, i, result)
            result.append(temp.nearest_point)
        for m in result:
            K_value[m[-1]] += 1
        max_index, max_number = max(enumerate(K_value), key=operator.itemgetter(1))
        if max_index != i[-1]:
            err += 1
        print('目前准确率为:', 1 - err / num)
    return (1 - err / len(data))


if __name__ == "__main__":   
    print('read train data')
    trainData = LoadData('Mnist/mnist_train/mnist_train.csv')

    print('read test data')
    testData = LoadData('Mnist/mnist_test/mnist_test.csv')

    print('start extract feature')
    trainDataFeature = CreatFeature(trainData)

    print('start creat tree')
    Tree = creat_kd_tree(trainDataFeature, 0)

    print('start test tree')
    testDataFeature = CreatFeature(testData)
    acc = test(Tree, testDataFeature, 20)
    print(acc)

其中训练数据样本数为60000,测试样本时间较长,我没有全部测试完成,测试100个样本,准确率如下:

测试数据编号: 100
目前准确率为: 0.85

准确率有点低,可能性是特征提取的不好。。我这里只是想这样子尝试一下。
代码部分。。后续有时间我整体整理一下,放在github上。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值