统计学习方法读书笔记6-K近邻算法及代码实现

这篇博客详细介绍了K-近邻算法,包括距离度量、K值选择和分类决策规则。同时,讨论了kd树的构造和搜索,结合P54-57页的内容深入理解kd树的应用。最后,提供了K近邻算法的Python代码实现。
摘要由CSDN通过智能技术生成

1.K-近邻算法

在这里插入图片描述

2.K-近邻模型(三个基本要素)

1.距离度量

在这里插入图片描述
在这里插入图片描述

2.K值的选择

在这里插入图片描述

3.分类决策规则

在这里插入图片描述

3.kd树

通过线性扫描实现k近邻算法,当训练集很大时,计算非常耗时

因此,需要考虑如何对训练数据进行快速的k近邻搜索

为了提高k近邻搜索效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离次数

kd树正是这个方法
1.构造平衡kd树

在这里插入图片描述
在这里插入图片描述
P54-55

2.搜索kd树

在这里插入图片描述
在这里插入图片描述
P56-57
在这里插入图片描述

4.K近邻代码实现

#!usr/bin/env python
# -*- coding:utf-8 _*-
"""
@author: liujie
@software: PyCharm
@file: KNN.py
@time: 2020/10/20 22:25
"""
# KNN没有显示的训练过程
import time
import numpy as np
from tqdm import tqdm


def loaddata(filename):
    """
    加载数据
    :param filename:文件路径
    :return: 返回数据集与标签
    """
    print('start to read file')
    # 存放数据与标签
    dataArr = []
    labelArr = []
    # 打开文件
    fr = open(filename)
    # 循环读取文件每一行
    for line in tqdm(fr.readlines()):
        # 获取当前行,并存放入列表中
        # strip:去掉每行字符串首尾指定的字符(默认空格或换行符)
        # split:按照指定的字符将字符串切割成每个字段,返回列表形式
        currentLine = line.strip().split(',')
        # 存放数据并转换成整型
        dataArr.append([int(num) for num in currentLine[1:]])
        # 存放标签并转换成整型
        labelArr.append(int(currentLine[0]))

    return dataArr, labelArr

def calDist(x1,x2):
    """
    计算欧式距离
    :param x1: 向量1
    :param x2: 向量2
    :return: 欧式距离
    """
    # 欧式距离
    return np.sqrt(np.sum(np.square(x1 - x2)))
    # 马哈顿距离
    # return np.sum(x1 - x2)

def getClosest(trainDataMat, trainLabelMat, x, topK):
    """
    预测x的标记
    多数表决
    :param trainDataMat:训练数据集
    :param trainLabelMat: 训练数据标签
    :param x: 预测样本x
    :param topK: 选择参考最邻近样本的数目
    :return: 预测的标记
    """

    # 建立一个存放向量x与每个训练集中样本距离的字典
    distDict = {}
    # 遍历训练集中所有样本点,计算与x的距离
    for i in tqdm(range(len(trainDataMat))):
        # 获取向量
        xi = trainDataMat[i]
        # 计算距离
        curDist = calDist(x,xi)
        # 将距离放入对应的字典位置
        distDict[i] = curDist

    # 对字典的value进行排序-升序排序-字典无法排序,但是可以建立有序的数据类型来表示排序后的值,该值将是一个列表-可能是一个元组列表
    # argsort:函数将数组的值从小到大排序后,并按照其相对应的索引值输出--列表
    # sorted:sorted(iterable,key,reverse),sorted一共有iterable,key,reverse这三个参数
    dist_list_topK = sorted(distDict.items(),key = lambda x:x[1],reverse=False)[:topK]
    # 转变成字典
    dist_dict_topk = dict(dist_list_topK)
    # print(dist_dict_topk)
    # dist_dict_topK中key进行循环
    labelLict = [0] * 10
    for index in dist_dict_topk:
        # 找到最近topk中的标签,并进行多数表决投票
        labelLict[int(trainLabelMat[index])] += 1

    # 找到选票箱中票数最多的票数值
    return np.argsort(np.array(labelLict))[-1]

# 定义计算正确率的函数
def model_test(trainData,trainLabel,testData,testLabel,topK):
    """
    测试正确率
    :param trainData:训练数据集
    :param trainLabel: 训练标签
    :param testData: 测试数据集
    :param testLabel: 测试标签
    :param topK: 选择多少个临近点参考
    :return: 正确率
    """
    print('start to test')
    # 将列表转化为矩阵,方便并行运算
    trainDataMat = np.mat(trainData)
    trainLabelMat = np.mat(trainLabel).T
    testDataMat = np.mat(testData)
    testLabelMat = np.mat(testLabel).T

    # 错误计数
    errorCnt = 0
    # 遍历测试集,对每个测试集样本进行测试
    # 由于计算向量与向量之间的时间耗费太大,测试集中每个样本都要计算与60000个样本的距离,所以这里人为改成了200个
    # for i in range(len(testDataMat)):
    #    print('test %d : %d'%(i,len(testDataMat)))
    for i in range(200):
        print('test %d : %d'%(i,200))
        # 获取测试向量与标签
        x = testDataMat[i]
        y = getClosest(trainDataMat,trainLabelMat,x,topK)

        # 预测标记与实际标记不符,错误计数加1
        if y != testLabelMat[i]:errorCnt += 1

    return 1 - errorCnt / 200




if __name__ == '__main__':
    start = time.time()
    # 加载数据集
    trainData, trainLabel = loaddata('data/mnist_train.csv')
    testData, testLabel = loaddata('data/mnist_test.csv')

    # 计算测试集的正确率
    accur = model_test(trainData,trainLabel,testData,testLabel,25)
    # 打印正确率
    print('accur : %d'%(accur*100),'%')

    end = time.time()
    print('time span:',end-start)
    
    
accur : 97 %
time span: 307.68515515327454
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值